summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrei Vagin <avagin@google.com>2019-10-02 13:00:07 -0700
committerGitHub <noreply@github.com>2019-10-02 13:00:07 -0700
commit9a875306dbabcf335a2abccc08119a1b67d0e51a (patch)
tree0f72c12e951a5eee7156df7a5d63351bc89befa6
parent38bc0b6b6addd25ceec4f66ef1af41c1e61e2985 (diff)
parent03ce4dd86c9acd6b6148f68d5d2cf025d8c254bb (diff)
Merge branch 'master' into pr_syscall_linux
-rw-r--r--.bazelrc2
-rw-r--r--CONTRIBUTING.md32
-rw-r--r--Makefile2
-rw-r--r--README.md2
-rw-r--r--WORKSPACE84
-rw-r--r--cloudbuild/go.Dockerfile2
-rw-r--r--cloudbuild/go.yaml22
-rw-r--r--kokoro/build.cfg23
-rw-r--r--kokoro/build_tests.cfg1
-rw-r--r--kokoro/common.cfg2
-rw-r--r--kokoro/continuous.cfg13
-rw-r--r--kokoro/do_tests.cfg9
-rw-r--r--kokoro/docker_tests.cfg10
-rw-r--r--kokoro/go.cfg20
-rw-r--r--kokoro/go_tests.cfg1
-rw-r--r--kokoro/hostnet_tests.cfg10
-rw-r--r--kokoro/kvm_tests.cfg10
-rw-r--r--kokoro/make_tests.cfg9
-rw-r--r--kokoro/overlay_tests.cfg10
-rw-r--r--kokoro/presubmit.cfg13
-rw-r--r--kokoro/release-nightly.cfg11
-rw-r--r--kokoro/release.cfg1
-rw-r--r--kokoro/root_tests.cfg10
l---------kokoro/run_build.sh1
l---------kokoro/run_tests.sh1
-rw-r--r--kokoro/simple_tests.cfg9
-rw-r--r--kokoro/syscall_tests.cfg9
-rwxr-xr-xkokoro/ubuntu1604/10_core.sh30
-rwxr-xr-xkokoro/ubuntu1604/20_bazel.sh28
-rwxr-xr-xkokoro/ubuntu1604/25_docker.sh35
-rwxr-xr-xkokoro/ubuntu1604/30_containerd.sh76
-rwxr-xr-xkokoro/ubuntu1604/40_kokoro.sh54
-rwxr-xr-xkokoro/ubuntu1604/build.sh20
l---------kokoro/ubuntu1804/10_core.sh1
l---------kokoro/ubuntu1804/20_bazel.sh1
l---------kokoro/ubuntu1804/25_docker.sh1
l---------kokoro/ubuntu1804/30_containerd.sh1
l---------kokoro/ubuntu1804/40_kokoro.sh1
-rwxr-xr-xkokoro/ubuntu1804/build.sh20
-rw-r--r--pkg/abi/linux/BUILD5
-rw-r--r--pkg/abi/linux/signalfd.go45
-rw-r--r--pkg/amutex/BUILD3
-rw-r--r--pkg/atomicbitops/BUILD3
-rw-r--r--pkg/binary/BUILD3
-rw-r--r--pkg/bits/BUILD3
-rw-r--r--pkg/bpf/BUILD4
-rw-r--r--pkg/compressio/BUILD3
-rw-r--r--pkg/cpuid/BUILD4
-rw-r--r--pkg/cpuid/cpuid.go203
-rw-r--r--pkg/cpuid/cpuid_test.go24
-rw-r--r--pkg/eventchannel/BUILD3
-rw-r--r--pkg/fd/BUILD6
-rw-r--r--pkg/fd/fd.go8
-rw-r--r--pkg/fdchannel/BUILD3
-rw-r--r--pkg/flipcall/BUILD4
-rw-r--r--pkg/flipcall/ctrl_futex.go32
-rw-r--r--pkg/flipcall/flipcall.go72
-rw-r--r--pkg/flipcall/flipcall_example_test.go4
-rw-r--r--pkg/flipcall/flipcall_test.go204
-rw-r--r--pkg/flipcall/flipcall_unsafe.go28
-rw-r--r--pkg/flipcall/futex_linux.go23
-rw-r--r--pkg/fspath/BUILD3
-rw-r--r--pkg/gate/BUILD3
-rw-r--r--pkg/ilist/BUILD3
-rw-r--r--pkg/linewriter/BUILD3
-rw-r--r--pkg/log/BUILD3
-rw-r--r--pkg/metric/BUILD10
-rw-r--r--pkg/p9/BUILD6
-rw-r--r--pkg/p9/client.go273
-rw-r--r--pkg/p9/client_test.go50
-rw-r--r--pkg/p9/handlers.go53
-rw-r--r--pkg/p9/messages.go103
-rw-r--r--pkg/p9/p9.go2
-rw-r--r--pkg/p9/p9test/BUILD6
-rw-r--r--pkg/p9/p9test/client_test.go95
-rw-r--r--pkg/p9/p9test/p9test.go4
-rw-r--r--pkg/p9/server.go178
-rw-r--r--pkg/p9/transport.go5
-rw-r--r--pkg/p9/transport_flipcall.go263
-rw-r--r--pkg/p9/transport_test.go4
-rw-r--r--pkg/p9/version.go9
-rw-r--r--pkg/procid/BUILD3
-rw-r--r--pkg/refs/BUILD4
-rw-r--r--pkg/refs/refcounter.go6
-rw-r--r--pkg/seccomp/BUILD4
-rw-r--r--pkg/seccomp/seccomp_unsafe.go2
-rw-r--r--pkg/secio/BUILD3
-rw-r--r--pkg/segment/test/BUILD3
-rw-r--r--pkg/sentry/BUILD2
-rw-r--r--pkg/sentry/arch/BUILD7
-rw-r--r--pkg/sentry/control/BUILD3
-rw-r--r--pkg/sentry/device/BUILD4
-rw-r--r--pkg/sentry/fs/BUILD4
-rw-r--r--pkg/sentry/fs/dirent.go4
-rw-r--r--pkg/sentry/fs/dirent_refs_test.go2
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD4
-rw-r--r--pkg/sentry/fs/file.go23
-rw-r--r--pkg/sentry/fs/file_operations.go9
-rw-r--r--pkg/sentry/fs/file_overlay.go9
-rw-r--r--pkg/sentry/fs/fsutil/BUILD4
-rw-r--r--pkg/sentry/fs/fsutil/file.go6
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go2
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go74
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached_test.go14
-rw-r--r--pkg/sentry/fs/gofer/BUILD4
-rw-r--r--pkg/sentry/fs/gofer/fs.go22
-rw-r--r--pkg/sentry/fs/gofer/inode.go15
-rw-r--r--pkg/sentry/fs/gofer/session.go27
-rw-r--r--pkg/sentry/fs/host/BUILD4
-rw-r--r--pkg/sentry/fs/host/inode.go10
-rw-r--r--pkg/sentry/fs/host/tty.go15
-rw-r--r--pkg/sentry/fs/inotify.go5
-rw-r--r--pkg/sentry/fs/lock/BUILD4
-rw-r--r--pkg/sentry/fs/mounts.go3
-rw-r--r--pkg/sentry/fs/proc/BUILD5
-rw-r--r--pkg/sentry/fs/proc/net.go201
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD4
-rw-r--r--pkg/sentry/fs/ramfs/BUILD4
-rw-r--r--pkg/sentry/fs/splice.go162
-rw-r--r--pkg/sentry/fs/timerfd/timerfd.go3
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD4
-rw-r--r--pkg/sentry/fs/tty/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD6
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/BUILD2
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go8
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go4
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/memfs/directory.go24
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD3
-rw-r--r--pkg/sentry/hostcpu/BUILD4
-rw-r--r--pkg/sentry/hostcpu/getcpu_arm64.s28
-rw-r--r--pkg/sentry/kernel/BUILD10
-rw-r--r--pkg/sentry/kernel/epoll/BUILD4
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD4
-rw-r--r--pkg/sentry/kernel/futex/BUILD4
-rw-r--r--pkg/sentry/kernel/kernel.go129
-rw-r--r--pkg/sentry/kernel/memevent/BUILD7
-rw-r--r--pkg/sentry/kernel/pipe/BUILD4
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go25
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go82
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go76
-rw-r--r--pkg/sentry/kernel/posixtimer.go8
-rw-r--r--pkg/sentry/kernel/sched/BUILD3
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD4
-rw-r--r--pkg/sentry/kernel/sessions.go8
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD22
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go137
-rw-r--r--pkg/sentry/kernel/task.go8
-rw-r--r--pkg/sentry/kernel/task_identity.go4
-rw-r--r--pkg/sentry/kernel/task_sched.go33
-rw-r--r--pkg/sentry/kernel/task_signals.go18
-rw-r--r--pkg/sentry/kernel/thread_group.go3
-rw-r--r--pkg/sentry/kernel/time/time.go27
-rw-r--r--pkg/sentry/limits/BUILD4
-rw-r--r--pkg/sentry/loader/elf.go20
-rw-r--r--pkg/sentry/loader/loader.go3
-rw-r--r--pkg/sentry/memmap/BUILD4
-rw-r--r--pkg/sentry/mm/BUILD4
-rw-r--r--pkg/sentry/pgalloc/BUILD4
-rw-r--r--pkg/sentry/platform/interrupt/BUILD3
-rw-r--r--pkg/sentry/platform/kvm/BUILD4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go16
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go20
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_linux.go4
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD3
-rw-r--r--pkg/sentry/platform/safecopy/BUILD3
-rw-r--r--pkg/sentry/safemem/BUILD3
-rw-r--r--pkg/sentry/sighandling/sighandling_unsafe.go2
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go227
-rw-r--r--pkg/sentry/socket/epsocket/provider.go2
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD4
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD9
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go82
-rw-r--r--pkg/sentry/strace/BUILD7
-rw-r--r--pkg/sentry/strace/linux64.go1
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go33
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go77
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go115
-rw-r--r--pkg/sentry/syscalls/linux/sys_time.go39
-rw-r--r--pkg/sentry/syscalls/linux/sys_utsname.go6
-rw-r--r--pkg/sentry/time/BUILD3
-rw-r--r--pkg/sentry/unimpl/BUILD7
-rw-r--r--pkg/sentry/usage/memory.go5
-rw-r--r--pkg/sentry/usermem/BUILD4
-rw-r--r--pkg/sentry/usermem/usermem.go8
-rw-r--r--pkg/sentry/vfs/BUILD3
-rw-r--r--pkg/sentry/vfs/file_description.go7
-rw-r--r--pkg/sleep/BUILD3
-rw-r--r--pkg/state/BUILD3
-rw-r--r--pkg/state/statefile/BUILD3
-rw-r--r--pkg/syserror/BUILD3
-rw-r--r--pkg/tcpip/BUILD4
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD3
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go5
-rw-r--r--pkg/tcpip/buffer/BUILD4
-rw-r--r--pkg/tcpip/checker/checker.go100
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD3
-rw-r--r--pkg/tcpip/header/BUILD4
-rw-r--r--pkg/tcpip/header/icmpv4.go71
-rw-r--r--pkg/tcpip/header/icmpv6.go88
-rw-r--r--pkg/tcpip/header/ipv4.go33
-rw-r--r--pkg/tcpip/header/ipv6.go55
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/link/channel/channel.go9
-rw-r--r--pkg/tcpip/link/fdbased/BUILD7
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go50
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go3
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go179
-rw-r--r--pkg/tcpip/link/fdbased/mmap_amd64.go194
-rw-r--r--pkg/tcpip/link/fdbased/mmap_stub.go (renamed from test/runtimes/runtimes.go)15
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go (renamed from pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go)2
-rw-r--r--pkg/tcpip/link/loopback/loopback.go7
-rw-r--r--pkg/tcpip/link/muxed/BUILD3
-rw-r--r--pkg/tcpip/link/muxed/injectable.go12
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/rawfile/BUILD5
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_arm64.s42
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_unsafe.go)4
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go (renamed from pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go)8
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go11
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go4
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go21
-rw-r--r--pkg/tcpip/link/waitable/BUILD3
-rw-r--r--pkg/tcpip/link/waitable/waitable.go10
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go9
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD3
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/arp/arp_test.go15
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD4
-rw-r--r--pkg/tcpip/network/ip_test.go25
-rw-r--r--pkg/tcpip/network/ipv4/BUILD3
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go44
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go31
-rw-r--r--pkg/tcpip/network/ipv6/BUILD10
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go61
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go47
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go24
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go266
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go181
-rw-r--r--pkg/tcpip/ports/BUILD3
-rw-r--r--pkg/tcpip/ports/ports.go161
-rw-r--r--pkg/tcpip/ports/ports_test.go169
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go9
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go9
-rw-r--r--pkg/tcpip/stack/BUILD33
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go41
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go253
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go79
-rw-r--r--pkg/tcpip/stack/nic.go441
-rw-r--r--pkg/tcpip/stack/registration.go75
-rw-r--r--pkg/tcpip/stack/route.go15
-rw-r--r--pkg/tcpip/stack/stack.go261
-rw-r--r--pkg/tcpip/stack/stack_test.go921
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go227
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go352
-rw-r--r--pkg/tcpip/stack/transport_test.go75
-rw-r--r--pkg/tcpip/tcpip.go106
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go58
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go38
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go37
-rw-r--r--pkg/tcpip/transport/raw/protocol.go9
-rw-r--r--pkg/tcpip/transport/tcp/BUILD5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go32
-rw-r--r--pkg/tcpip/transport/tcp/connect.go19
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go285
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go24
-rw-r--r--pkg/tcpip/transport/tcp/snd.go9
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go392
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go52
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD3
-rw-r--r--pkg/tcpip/transport/udp/BUILD5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go151
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go4
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go4
-rw-r--r--pkg/tcpip/transport/udp/protocol.go113
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go290
-rw-r--r--pkg/tmutex/BUILD3
-rw-r--r--pkg/unet/BUILD3
-rw-r--r--pkg/urpc/BUILD3
-rw-r--r--pkg/waiter/BUILD4
-rw-r--r--runsc/BUILD29
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/config.go63
-rw-r--r--runsc/boot/filter/config.go14
-rw-r--r--runsc/boot/fs.go159
-rw-r--r--runsc/boot/loader.go121
-rw-r--r--runsc/boot/loader_test.go17
-rw-r--r--runsc/boot/network.go18
-rw-r--r--runsc/boot/user.go28
-rw-r--r--runsc/boot/user_test.go15
-rw-r--r--runsc/cgroup/BUILD4
-rw-r--r--runsc/cmd/BUILD3
-rw-r--r--runsc/cmd/capability_test.go4
-rw-r--r--runsc/cmd/exec.go53
-rw-r--r--runsc/cmd/exec_test.go4
-rw-r--r--runsc/cmd/gofer.go5
-rw-r--r--runsc/cmd/install.go210
-rw-r--r--runsc/container/BUILD3
-rw-r--r--runsc/container/console_test.go2
-rw-r--r--runsc/container/container.go113
-rw-r--r--runsc/container/container_test.go74
-rw-r--r--runsc/container/multi_container_test.go71
-rw-r--r--runsc/container/shared_volume_test.go2
-rw-r--r--runsc/container/test_app/BUILD2
-rw-r--r--runsc/container/test_app/fds.go4
-rw-r--r--runsc/container/test_app/test_app.go67
-rw-r--r--runsc/criutil/BUILD12
-rw-r--r--runsc/criutil/criutil.go (renamed from runsc/test/testutil/crictl.go)13
-rwxr-xr-xrunsc/debian/postinst.sh6
-rw-r--r--runsc/dockerutil/BUILD15
-rw-r--r--runsc/dockerutil/dockerutil.go (renamed from runsc/test/testutil/docker.go)115
-rw-r--r--runsc/fsgofer/filter/BUILD1
-rw-r--r--runsc/fsgofer/filter/config.go49
-rw-r--r--runsc/fsgofer/filter/filter.go13
-rw-r--r--runsc/fsgofer/fsgofer.go70
-rw-r--r--runsc/fsgofer/fsgofer_test.go12
-rw-r--r--runsc/main.go25
-rw-r--r--runsc/sandbox/sandbox.go10
-rw-r--r--runsc/specutils/BUILD1
-rw-r--r--runsc/specutils/namespace.go14
-rw-r--r--runsc/specutils/specutils.go65
-rw-r--r--runsc/test/BUILD0
-rw-r--r--runsc/test/README.md24
-rw-r--r--runsc/test/build_defs.bzl19
-rwxr-xr-xrunsc/test/install.sh93
-rw-r--r--runsc/test/integration/exec_test.go161
-rw-r--r--runsc/testutil/BUILD (renamed from runsc/test/testutil/BUILD)11
-rw-r--r--runsc/testutil/testutil.go (renamed from runsc/test/testutil/testutil.go)80
-rw-r--r--runsc/tools/dockercfg/BUILD10
-rw-r--r--runsc/tools/dockercfg/dockercfg.go193
-rw-r--r--runsc/version.go2
-rwxr-xr-xrunsc/version_test.sh36
-rwxr-xr-xscripts/build.sh79
-rwxr-xr-xscripts/common.sh80
-rwxr-xr-xscripts/common_bazel.sh99
-rwxr-xr-xscripts/dev.sh73
-rwxr-xr-xscripts/do_tests.sh27
-rwxr-xr-xscripts/docker_tests.sh20
-rwxr-xr-xscripts/go.sh43
-rwxr-xr-xscripts/hostnet_tests.sh21
-rwxr-xr-xscripts/kvm_tests.sh28
-rwxr-xr-xscripts/make_tests.sh25
-rwxr-xr-xscripts/overlay_tests.sh21
-rwxr-xr-xscripts/release.sh38
-rwxr-xr-xscripts/root_tests.sh31
-rwxr-xr-xscripts/simple_tests.sh20
-rwxr-xr-xscripts/syscall_tests.sh20
-rw-r--r--test/README.md18
-rw-r--r--test/e2e/BUILD (renamed from runsc/test/integration/BUILD)13
-rw-r--r--test/e2e/exec_test.go256
-rw-r--r--test/e2e/integration.go (renamed from runsc/test/integration/integration.go)0
-rw-r--r--test/e2e/integration_test.go (renamed from runsc/test/integration/integration_test.go)46
-rw-r--r--test/e2e/regression_test.go (renamed from runsc/test/integration/regression_test.go)6
-rw-r--r--test/image/BUILD (renamed from runsc/test/image/BUILD)13
-rw-r--r--test/image/image.go (renamed from runsc/test/image/image.go)0
-rw-r--r--test/image/image_test.go (renamed from runsc/test/image/image_test.go)65
-rw-r--r--test/image/latin10k.txt (renamed from runsc/test/image/latin10k.txt)0
-rw-r--r--test/image/mysql.sql (renamed from runsc/test/image/mysql.sql)0
-rw-r--r--test/image/ruby.rb (renamed from runsc/test/image/ruby.rb)0
-rw-r--r--test/image/ruby.sh (renamed from runsc/test/image/ruby.sh)0
-rw-r--r--test/root/BUILD (renamed from runsc/test/root/BUILD)17
-rw-r--r--test/root/cgroup_test.go (renamed from runsc/test/root/cgroup_test.go)29
-rw-r--r--test/root/chroot_test.go (renamed from runsc/test/root/chroot_test.go)27
-rw-r--r--test/root/crictl_test.go (renamed from runsc/test/root/crictl_test.go)44
-rw-r--r--test/root/main_test.go49
-rw-r--r--test/root/oom_score_adj_test.go376
-rw-r--r--test/root/root.go21
-rw-r--r--test/root/testdata/BUILD (renamed from runsc/test/root/testdata/BUILD)2
-rw-r--r--test/root/testdata/busybox.go (renamed from runsc/test/root/testdata/busybox.go)0
-rw-r--r--test/root/testdata/containerd_config.go (renamed from runsc/test/root/testdata/containerd_config.go)0
-rw-r--r--test/root/testdata/httpd.go (renamed from runsc/test/root/testdata/httpd.go)0
-rw-r--r--test/root/testdata/httpd_mount_paths.go (renamed from runsc/test/root/testdata/httpd_mount_paths.go)0
-rw-r--r--test/root/testdata/sandbox.go (renamed from runsc/test/root/testdata/sandbox.go)0
-rw-r--r--test/runtimes/BUILD49
-rw-r--r--test/runtimes/README.md5
-rw-r--r--test/runtimes/blacklist_nodejs12.4.0.csv47
-rw-r--r--test/runtimes/build_defs.bzl42
-rw-r--r--test/runtimes/common/BUILD20
-rw-r--r--test/runtimes/common/common.go114
-rw-r--r--test/runtimes/go/BUILD9
-rw-r--r--test/runtimes/go/Dockerfile34
-rw-r--r--test/runtimes/images/Dockerfile_go1.1210
-rw-r--r--test/runtimes/images/Dockerfile_java1130
-rw-r--r--test/runtimes/images/Dockerfile_nodejs12.4.028
-rw-r--r--test/runtimes/images/Dockerfile_php7.3.627
-rw-r--r--test/runtimes/images/Dockerfile_python3.7.330
-rw-r--r--test/runtimes/images/proctor/BUILD26
-rw-r--r--test/runtimes/images/proctor/go.go (renamed from test/runtimes/go/proctor-go.go)67
-rw-r--r--test/runtimes/images/proctor/java.go (renamed from test/runtimes/java/proctor-java.go)49
-rw-r--r--test/runtimes/images/proctor/nodejs.go46
-rw-r--r--test/runtimes/images/proctor/php.go (renamed from test/runtimes/php/proctor-php.go)36
-rw-r--r--test/runtimes/images/proctor/proctor.go154
-rw-r--r--test/runtimes/images/proctor/proctor_test.go (renamed from test/runtimes/common/common_test.go)13
-rw-r--r--test/runtimes/images/proctor/python.go (renamed from test/runtimes/python/proctor-python.go)34
-rw-r--r--test/runtimes/java/BUILD9
-rw-r--r--test/runtimes/java/Dockerfile35
-rw-r--r--test/runtimes/nodejs/BUILD9
-rw-r--r--test/runtimes/nodejs/Dockerfile30
-rw-r--r--test/runtimes/nodejs/proctor-nodejs.go60
-rw-r--r--test/runtimes/php/BUILD9
-rw-r--r--test/runtimes/php/Dockerfile30
-rw-r--r--test/runtimes/python/BUILD9
-rw-r--r--test/runtimes/python/Dockerfile32
-rw-r--r--test/runtimes/runner.go199
-rwxr-xr-xtest/runtimes/runner.sh35
-rw-r--r--test/runtimes/runtimes_test.go93
-rw-r--r--test/syscalls/BUILD14
-rw-r--r--test/syscalls/build_defs.bzl4
-rw-r--r--test/syscalls/linux/BUILD171
-rw-r--r--test/syscalls/linux/affinity.cc1
-rw-r--r--test/syscalls/linux/aio.cc155
-rw-r--r--test/syscalls/linux/base_poll_test.h2
-rw-r--r--test/syscalls/linux/chown.cc30
-rw-r--r--test/syscalls/linux/clock_nanosleep.cc86
-rw-r--r--test/syscalls/linux/exec_binary.cc70
-rw-r--r--test/syscalls/linux/fcntl.cc47
-rw-r--r--test/syscalls/linux/futex.cc16
-rw-r--r--test/syscalls/linux/ip_socket_test_util.cc10
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h25
-rw-r--r--test/syscalls/linux/itimer.cc4
-rw-r--r--test/syscalls/linux/kill.cc13
-rw-r--r--test/syscalls/linux/link.cc6
-rw-r--r--test/syscalls/linux/mlock.cc38
-rw-r--r--test/syscalls/linux/mremap.cc11
-rw-r--r--test/syscalls/linux/open.cc1
-rw-r--r--test/syscalls/linux/packet_socket.cc29
-rw-r--r--test/syscalls/linux/packet_socket_raw.cc21
-rw-r--r--test/syscalls/linux/pipe.cc14
-rw-r--r--test/syscalls/linux/prctl.cc9
-rw-r--r--test/syscalls/linux/prctl_setuid.cc24
-rw-r--r--test/syscalls/linux/proc.cc13
-rw-r--r--test/syscalls/linux/proc_net.cc5
-rw-r--r--test/syscalls/linux/proc_net_tcp.cc65
-rw-r--r--test/syscalls/linux/proc_net_udp.cc309
-rw-r--r--test/syscalls/linux/proc_net_unix.cc54
-rw-r--r--test/syscalls/linux/ptrace.cc11
-rw-r--r--test/syscalls/linux/pty.cc87
-rw-r--r--test/syscalls/linux/pty_root.cc2
-rw-r--r--test/syscalls/linux/pwritev2.cc16
-rw-r--r--test/syscalls/linux/raw_socket_hdrincl.cc43
-rw-r--r--test/syscalls/linux/raw_socket_icmp.cc13
-rw-r--r--test/syscalls/linux/raw_socket_ipv4.cc13
-rw-r--r--test/syscalls/linux/readahead.cc91
-rw-r--r--test/syscalls/linux/semaphore.cc2
-rw-r--r--test/syscalls/linux/sendfile.cc91
-rw-r--r--test/syscalls/linux/signalfd.cc333
-rw-r--r--test/syscalls/linux/sigstop.cc7
-rw-r--r--test/syscalls/linux/socket.cc19
-rw-r--r--test/syscalls/linux/socket_bind_to_device.cc314
-rw-r--r--test/syscalls/linux/socket_bind_to_device_distribution.cc381
-rw-r--r--test/syscalls/linux/socket_bind_to_device_sequence.cc316
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.cc75
-rw-r--r--test/syscalls/linux/socket_bind_to_device_util.h67
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc2
-rw-r--r--test/syscalls/linux/socket_test_util.cc10
-rw-r--r--test/syscalls/linux/socket_test_util.h5
-rw-r--r--test/syscalls/linux/socket_test_util_impl.cc28
-rw-r--r--test/syscalls/linux/splice.cc194
-rw-r--r--test/syscalls/linux/sticky.cc31
-rw-r--r--test/syscalls/linux/tcp_socket.cc21
-rw-r--r--test/syscalls/linux/timers.cc7
-rw-r--r--test/syscalls/linux/uidgid.cc48
-rw-r--r--test/syscalls/linux/uname.cc14
-rw-r--r--test/syscalls/linux/unlink.cc2
-rw-r--r--test/syscalls/linux/vfork.cc7
-rw-r--r--test/syscalls/syscall_test_runner.go34
-rw-r--r--test/util/BUILD23
-rw-r--r--test/util/fs_util.cc9
-rw-r--r--test/util/fs_util.h3
-rw-r--r--test/util/memory_util.h12
-rw-r--r--test/util/proc_util.cc2
-rw-r--r--test/util/save_util.cc12
-rw-r--r--test/util/save_util.h5
-rw-r--r--test/util/save_util_linux.cc33
-rw-r--r--test/util/save_util_other.cc (renamed from runsc/test/testutil/testutil_race.go)14
-rw-r--r--test/util/test_util.cc4
-rw-r--r--test/util/test_util.h1
-rw-r--r--test/util/thread_util.h18
-rw-r--r--test/util/uid_util.cc44
-rw-r--r--test/util/uid_util.h29
-rw-r--r--third_party/gvsync/downgradable_rwmutex_unsafe.go3
-rwxr-xr-xtools/go_branch.sh6
-rw-r--r--tools/go_marshal/BUILD14
-rw-r--r--tools/go_marshal/README.md164
-rw-r--r--tools/go_marshal/analysis/BUILD13
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go175
-rw-r--r--tools/go_marshal/defs.bzl152
-rw-r--r--tools/go_marshal/gomarshal/BUILD17
-rw-r--r--tools/go_marshal/gomarshal/generator.go382
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go507
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go154
-rw-r--r--tools/go_marshal/gomarshal/util.go387
-rw-r--r--tools/go_marshal/main.go73
-rw-r--r--tools/go_marshal/marshal/BUILD14
-rw-r--r--tools/go_marshal/marshal/marshal.go60
-rw-r--r--tools/go_marshal/test/BUILD31
-rw-r--r--tools/go_marshal/test/benchmark_test.go178
-rw-r--r--tools/go_marshal/test/external/BUILD11
-rw-r--r--tools/go_marshal/test/external/external.go (renamed from runsc/test/root/root.go)13
-rw-r--r--tools/go_marshal/test/test.go105
-rw-r--r--tools/go_stateify/defs.bzl65
-rwxr-xr-xtools/image_build.sh98
-rwxr-xr-xtools/make_repository.sh79
-rwxr-xr-xtools/run_build.sh49
-rwxr-xr-xtools/run_tests.sh302
-rwxr-xr-xtools/workspace_status.sh3
517 files changed, 17708 insertions, 5266 deletions
diff --git a/.bazelrc b/.bazelrc
index eda884473..379fc8328 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -13,7 +13,7 @@
# limitations under the License.
# Display the current git revision in the info block.
-build --workspace_status_command tools/workspace_status.sh
+build --stamp --workspace_status_command tools/workspace_status.sh
# Enable remote execution so actions are performed on the remote systems.
build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 638942a42..5d46168bc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -83,6 +83,8 @@ Rules:
### Code reviews
+Before sending code reviews, run `bazel test ...` to ensure tests are passing.
+
Code changes are accepted via [pull request][github].
When approved, the change will be submitted by a team member and automatically
@@ -100,6 +102,36 @@ form `b/1234`. These correspond to bugs in our internal bug tracker. Eventually
these bugs will be moved to the GitHub Issues, but until then they can simply be
ignored.
+### Build and test with Docker
+
+`scripts/dev.sh` is a convenient script that builds and installs `runsc` as a
+new Docker runtime for you. The scripts tries to extract the runtime name from
+your local environment and will print it at the end. You can also customize it.
+The script creates one regular runtime and another with debug flags enabled.
+Here are a few examples:
+
+```bash
+# Default case (inside branch my-branch)
+$ scripts/dev.sh
+...
+Runtimes my-branch and my-branch-d (debug enabled) setup.
+Use --runtime=my-branch with your Docker command.
+ docker run --rm --runtime=my-branch --rm hello-world
+
+If you rebuild, use scripts/dev.sh --refresh.
+Logs are in: /tmp/my-branch/logs
+
+# --refresh just updates the runtime binary and doesn't restart docker.
+$ git/my_branch> scripts/dev.sh --refresh
+
+# Using a custom runtime name
+$ git/my_branch> scripts/dev.sh my-runtime
+...
+Runtimes my-runtime and my-runtime-d (debug enabled) setup.
+Use --runtime=my-runtime with your Docker command.
+ docker run --rm --runtime=my-runtime --rm hello-world
+```
+
### The small print
Contributions made by corporations are covered by a different agreement than the
diff --git a/Makefile b/Makefile
index 561618478..1735c07df 100644
--- a/Makefile
+++ b/Makefile
@@ -22,7 +22,7 @@ bazel-server-start: docker-build
--privileged \
gvisor-bazel \
sh -c "while :; do sleep 100; done" && \
- docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) gvisor && useradd --uid $(UID) --gid $(GID) -d $(HOME) gvisor"
+ docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --gid $(GID) -d $(HOME) gvisor"
bazel-server:
docker exec gvisor-bazel true || \
diff --git a/README.md b/README.md
index d102845ac..7ab76d305 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ Make sure the following dependencies are installed:
* Linux 4.14.77+ ([older linux][old-linux])
* [git][git]
-* [Bazel][bazel] 0.23.0+
+* [Bazel][bazel] 0.28.0+
* [Python][python]
* [Docker version 17.09.0 or greater][docker]
* Gold linker (e.g. `binutils-gold` package on Ubuntu)
diff --git a/WORKSPACE b/WORKSPACE
index e5c5dfa2b..082e26ee9 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -3,19 +3,19 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_go",
- sha256 = "313f2c7a23fecc33023563f082f381a32b9b7254f727a7dd2d6380ccc6dfe09b",
+ sha256 = "513c12397db1bc9aa46dd62f02dd94b49a9b5d17444d49b5a04c5a89f3053c1c",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.19.5/rules_go-v0.19.5.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.19.5/rules_go-v0.19.5.tar.gz",
],
)
http_archive(
name = "bazel_gazelle",
- sha256 = "be9296bfd64882e3c08e3283c58fcb461fa6dd3c171764fcc4cf322f60615a9b",
+ sha256 = "7fc87f4170011201b1690326e8c16c5d802836e3a0d617d8f75c3af2b23180c4",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz",
- "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz",
],
)
@@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.12.9",
+ go_version = "1.13.1",
nogo = "@//:nogo",
)
@@ -32,14 +32,26 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
gazelle_dependencies()
-# Load protobuf dependencies.
-load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+# Load C++ rules.
+http_archive(
+ name = "rules_cc",
+ sha256 = "67412176974bfce3f4cf8bdaff39784a72ed709fc58def599d1f68710b58d68b",
+ strip_prefix = "rules_cc-b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip",
+ "https://github.com/bazelbuild/rules_cc/archive/b7fe9697c0c76ab2fd431a891dbb9a6a32ed7c3e.zip",
+ ],
+)
-git_repository(
+# Load protobuf dependencies.
+http_archive(
name = "com_google_protobuf",
- commit = "09745575a923640154bcf307fba8aedff47f240a",
- remote = "https://github.com/protocolbuffers/protobuf",
- shallow_since = "1558721209 -0700",
+ sha256 = "532d2575d8c0992065bb19ec5fba13aa3683499726f6055c11b474f91a00bb0c",
+ strip_prefix = "protobuf-7f520092d9050d96fb4b707ad11a51701af4ce49",
+ urls = [
+ "https://mirror.bazel.build/github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
+ "https://github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
+ ],
)
load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
@@ -50,11 +62,11 @@ protobuf_deps()
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "e71eadcfcbdb47b4b740eb48b32ca4226e36aabc425d035a18dd40c2dda808c1",
- strip_prefix = "bazel-toolchains-0.28.4",
+ sha256 = "a019fbd579ce5aed0239de865b2d8281dbb809efd537bf42e0d366783e8dec65",
+ strip_prefix = "bazel-toolchains-0.29.2",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
],
)
@@ -63,6 +75,16 @@ load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig")
rbe_autoconfig(name = "rbe_default")
+http_archive(
+ name = "rules_pkg",
+ sha256 = "5bdc04987af79bd27bc5b00fe30f59a858f77ffa0bd2d8143d5b31ad8b1bd71c",
+ url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.0/rules_pkg-0.2.0.tar.gz",
+)
+
+load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
+
+rules_pkg_dependencies()
+
# External repositories, in sorted order.
go_repository(
name = "com_github_cenkalti_backoff",
@@ -184,7 +206,7 @@ go_repository(
go_repository(
name = "org_golang_x_time",
- commit = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef",
+ commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
)
@@ -210,31 +232,21 @@ go_repository(
# System Call test dependencies.
http_archive(
- name = "com_github_gflags_gflags",
- sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
- strip_prefix = "gflags-2.2.2",
- urls = [
- "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.2.tar.gz",
- "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz",
- ],
-)
-
-http_archive(
name = "com_google_absl",
- sha256 = "01ba1185a0e6e048e4890f39e383515195bc335f0627cdddc0c325ee68be4434",
- strip_prefix = "abseil-cpp-cd86d0d20ab167c33b23d3875db68d1d4bad3a3b",
+ sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93",
+ strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/cd86d0d20ab167c33b23d3875db68d1d4bad3a3b.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/cd86d0d20ab167c33b23d3875db68d1d4bad3a3b.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/a0d1e098c2f99694fa399b175a7ccf920762030e.tar.gz",
],
)
http_archive(
name = "com_google_googletest",
- sha256 = "db657310d3c5ca2d3f674e3a4b79718d1d39da70604568ee0568ba8e39065ef4",
- strip_prefix = "googletest-31200def0dec8a624c861f919e86e4444e6e6ee7",
+ sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912",
+ strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34",
urls = [
- "https://mirror.bazel.build/github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz",
- "https://github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz",
+ "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
+ "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
],
)
diff --git a/cloudbuild/go.Dockerfile b/cloudbuild/go.Dockerfile
deleted file mode 100644
index 226442fd2..000000000
--- a/cloudbuild/go.Dockerfile
+++ /dev/null
@@ -1,2 +0,0 @@
-FROM ubuntu
-RUN apt-get -q update && apt-get install -qqy git rsync
diff --git a/cloudbuild/go.yaml b/cloudbuild/go.yaml
deleted file mode 100644
index a38ef71fc..000000000
--- a/cloudbuild/go.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-steps:
-- name: 'gcr.io/cloud-builders/git'
- args: ['fetch', '--all', '--unshallow']
-- name: 'gcr.io/cloud-builders/bazel'
- args: ['build', ':gopath']
-- name: 'gcr.io/cloud-builders/docker'
- args: ['build', '-t', 'gcr.io/$PROJECT_ID/go-branch', '-f', 'cloudbuild/go.Dockerfile', '.']
-- name: 'gcr.io/$PROJECT_ID/go-branch'
- args: ['tools/go_branch.sh']
-- name: 'gcr.io/cloud-builders/git'
- args: ['checkout', 'go']
-- name: 'gcr.io/cloud-builders/git'
- args: ['clean', '-f']
-- name: 'golang'
- args: ['go', 'build', './...']
-- name: 'gcr.io/cloud-builders/git'
- entrypoint: 'bash'
- args:
- - '-c'
- - 'if [[ "$BRANCH_NAME" == "master" ]]; then git push "${_ORIGIN}" go:go; fi'
-substitutions:
- _ORIGIN: origin
diff --git a/kokoro/build.cfg b/kokoro/build.cfg
new file mode 100644
index 000000000..6c1d262d4
--- /dev/null
+++ b/kokoro/build.cfg
@@ -0,0 +1,23 @@
+build_file: "repo/scripts/build.sh"
+
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-repo-key"
+ }
+ }
+}
+
+env_vars {
+ key: "KOKORO_REPO_KEY"
+ value: "73898_kokoro-repo-key"
+}
+
+action {
+ define_artifacts {
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ regex: "**/dists/**"
+ }
+}
diff --git a/kokoro/build_tests.cfg b/kokoro/build_tests.cfg
new file mode 100644
index 000000000..c64b7e679
--- /dev/null
+++ b/kokoro/build_tests.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/build.sh"
diff --git a/kokoro/common.cfg b/kokoro/common.cfg
index cad873fe1..669a2e458 100644
--- a/kokoro/common.cfg
+++ b/kokoro/common.cfg
@@ -10,7 +10,7 @@ before_action {
# Configure bazel to access RBE.
bazel_setting {
- # Our GCP project name
+ # Our GCP project name.
project_id: "gvisor-rbe"
# Use RBE for execution as well as caching.
diff --git a/kokoro/continuous.cfg b/kokoro/continuous.cfg
deleted file mode 100644
index 8da47736a..000000000
--- a/kokoro/continuous.cfg
+++ /dev/null
@@ -1,13 +0,0 @@
-# Location of bash script that runs the test. The first directory in the path
-# is the directory where Kokoro will check out the repo. The rest is the path
-# is the path to the test script.
-build_file: "repo/kokoro/run_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc-logs.tar.gz"
- }
-}
diff --git a/kokoro/do_tests.cfg b/kokoro/do_tests.cfg
new file mode 100644
index 000000000..b45ec0b42
--- /dev/null
+++ b/kokoro/do_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/do_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/docker_tests.cfg b/kokoro/docker_tests.cfg
new file mode 100644
index 000000000..0a0ef87ed
--- /dev/null
+++ b/kokoro/docker_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/docker_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/go.cfg b/kokoro/go.cfg
new file mode 100644
index 000000000..b9c1fcb12
--- /dev/null
+++ b/kokoro/go.cfg
@@ -0,0 +1,20 @@
+build_file: "repo/scripts/go.sh"
+
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-github-access-token"
+ }
+ }
+}
+
+env_vars {
+ key: "KOKORO_GITHUB_ACCESS_TOKEN"
+ value: "73898_kokoro-github-access-token"
+}
+
+env_vars {
+ key: "KOKORO_GO_PUSH"
+ value: "true"
+}
diff --git a/kokoro/go_tests.cfg b/kokoro/go_tests.cfg
new file mode 100644
index 000000000..5eb51041a
--- /dev/null
+++ b/kokoro/go_tests.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/go.sh"
diff --git a/kokoro/hostnet_tests.cfg b/kokoro/hostnet_tests.cfg
new file mode 100644
index 000000000..520dc55a3
--- /dev/null
+++ b/kokoro/hostnet_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/hostnet_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/kvm_tests.cfg b/kokoro/kvm_tests.cfg
new file mode 100644
index 000000000..1feb60c8a
--- /dev/null
+++ b/kokoro/kvm_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/kvm_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/make_tests.cfg b/kokoro/make_tests.cfg
new file mode 100644
index 000000000..d973130ff
--- /dev/null
+++ b/kokoro/make_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/make_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/overlay_tests.cfg b/kokoro/overlay_tests.cfg
new file mode 100644
index 000000000..6a2ddbd03
--- /dev/null
+++ b/kokoro/overlay_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/overlay_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/presubmit.cfg b/kokoro/presubmit.cfg
deleted file mode 100644
index 8da47736a..000000000
--- a/kokoro/presubmit.cfg
+++ /dev/null
@@ -1,13 +0,0 @@
-# Location of bash script that runs the test. The first directory in the path
-# is the directory where Kokoro will check out the repo. The rest is the path
-# is the path to the test script.
-build_file: "repo/kokoro/run_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- regex: "**/runsc-logs.tar.gz"
- }
-}
diff --git a/kokoro/release-nightly.cfg b/kokoro/release-nightly.cfg
deleted file mode 100644
index e5087b1cd..000000000
--- a/kokoro/release-nightly.cfg
+++ /dev/null
@@ -1,11 +0,0 @@
-# Location of bash script that builds a release.
-build_file: "repo/kokoro/run_build.sh"
-
-action {
- # Upload runsc binary and its checksum. It may be in multiple paths, so we
- # must use the wildcard.
- define_artifacts {
- regex: "**/runsc"
- regex: "**/runsc.sha512"
- }
-}
diff --git a/kokoro/release.cfg b/kokoro/release.cfg
new file mode 100644
index 000000000..b9d35bc51
--- /dev/null
+++ b/kokoro/release.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/release.sh"
diff --git a/kokoro/root_tests.cfg b/kokoro/root_tests.cfg
new file mode 100644
index 000000000..28351695c
--- /dev/null
+++ b/kokoro/root_tests.cfg
@@ -0,0 +1,10 @@
+build_file: "repo/scripts/root_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
+ }
+}
diff --git a/kokoro/run_build.sh b/kokoro/run_build.sh
deleted file mode 120000
index 9deafe9bb..000000000
--- a/kokoro/run_build.sh
+++ /dev/null
@@ -1 +0,0 @@
-../tools/run_build.sh \ No newline at end of file
diff --git a/kokoro/run_tests.sh b/kokoro/run_tests.sh
deleted file mode 120000
index 931cd2622..000000000
--- a/kokoro/run_tests.sh
+++ /dev/null
@@ -1 +0,0 @@
-../tools/run_tests.sh \ No newline at end of file
diff --git a/kokoro/simple_tests.cfg b/kokoro/simple_tests.cfg
new file mode 100644
index 000000000..32e0a9431
--- /dev/null
+++ b/kokoro/simple_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/simple_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/syscall_tests.cfg b/kokoro/syscall_tests.cfg
new file mode 100644
index 000000000..ee6e4a3a4
--- /dev/null
+++ b/kokoro/syscall_tests.cfg
@@ -0,0 +1,9 @@
+build_file: "repo/scripts/syscall_tests.sh"
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ }
+}
diff --git a/kokoro/ubuntu1604/10_core.sh b/kokoro/ubuntu1604/10_core.sh
new file mode 100755
index 000000000..e87a6eee8
--- /dev/null
+++ b/kokoro/ubuntu1604/10_core.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Install all essential build tools.
+apt-get update && apt-get -y install make git-core build-essential linux-headers-$(uname -r) pkg-config
+
+# Install a recent go toolchain.
+if ! [[ -d /usr/local/go ]]; then
+ wget https://dl.google.com/go/go1.12.linux-amd64.tar.gz
+ tar -xvf go1.12.linux-amd64.tar.gz
+ mv go /usr/local
+fi
+
+# Link the Go binary from /usr/bin; replacing anything there.
+(cd /usr/bin && rm -f go && sudo ln -fs /usr/local/go/bin/go go)
diff --git a/kokoro/ubuntu1604/20_bazel.sh b/kokoro/ubuntu1604/20_bazel.sh
new file mode 100755
index 000000000..b9a894024
--- /dev/null
+++ b/kokoro/ubuntu1604/20_bazel.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+declare -r BAZEL_VERSION=0.29.1
+
+# Install bazel dependencies.
+apt-get update && apt-get install -y openjdk-8-jdk-headless unzip
+
+# Use the release installer.
+curl -L -o bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+chmod a+x bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+./bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
+rm -f bazel-${BAZEL_VERSION}-installer-linux-x86_64.sh
diff --git a/kokoro/ubuntu1604/25_docker.sh b/kokoro/ubuntu1604/25_docker.sh
new file mode 100755
index 000000000..1d3defcd3
--- /dev/null
+++ b/kokoro/ubuntu1604/25_docker.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Add dependencies.
+apt-get update && apt-get -y install \
+ apt-transport-https \
+ ca-certificates \
+ curl \
+ gnupg-agent \
+ software-properties-common
+
+# Install the key.
+curl -fsSL https://download.docker.com/linux/ubuntu/gpg | apt-key add -
+
+# Add the repository.
+add-apt-repository \
+ "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
+ $(lsb_release -cs) \
+ stable"
+
+# Install docker.
+apt-get update && apt-get install -y docker-ce docker-ce-cli containerd.io
diff --git a/kokoro/ubuntu1604/30_containerd.sh b/kokoro/ubuntu1604/30_containerd.sh
new file mode 100755
index 000000000..a7472bd1c
--- /dev/null
+++ b/kokoro/ubuntu1604/30_containerd.sh
@@ -0,0 +1,76 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Helper for Go packages below.
+install_helper() {
+ PACKAGE="${1}"
+ TAG="${2}"
+ GOPATH="${3}"
+
+ # Clone the repository.
+ mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \
+ git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}"
+
+ # Checkout and build the repository.
+ (cd "${GOPATH}"/src/"${PACKAGE}" && \
+ git checkout "${TAG}" && \
+ GOPATH="${GOPATH}" make && \
+ GOPATH="${GOPATH}" make install)
+}
+
+# Install dependencies for the crictl tests.
+apt-get install -y btrfs-tools libseccomp-dev
+
+# Install containerd & cri-tools.
+GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
+install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}"
+install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}"
+
+# Install gvisor-containerd-shim.
+declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
+declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
+declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
+wget --no-verbose "${base}"/latest -O ${latest}
+wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
+chmod +x ${shim_path}
+mv ${shim_path} /usr/local/bin
+
+# Configure containerd-shim.
+declare -r shim_config_path=/etc/containerd
+declare -r shim_config_tmp_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX.toml)
+mkdir -p ${shim_config_path}
+cat > ${shim_config_tmp_path} <<-EOF
+ runc_shim = "/usr/local/bin/containerd-shim"
+
+[runsc_config]
+ debug = "true"
+ debug-log = "/tmp/runsc-logs/"
+ strace = "true"
+ file-access = "shared"
+EOF
+mv ${shim_config_tmp_path} ${shim_config_path}
+
+# Configure CNI.
+(cd "${GOPATH}" && GOPATH="${GOPATH}" \
+ src/github.com/containerd/containerd/script/setup/install-cni)
+
+# Cleanup the above.
+rm -rf "${GOPATH}"
+rm -rf "${latest}"
+rm -rf "${shim_path}"
+rm -rf "${shim_config_tmp_path}"
diff --git a/kokoro/ubuntu1604/40_kokoro.sh b/kokoro/ubuntu1604/40_kokoro.sh
new file mode 100755
index 000000000..64772d74d
--- /dev/null
+++ b/kokoro/ubuntu1604/40_kokoro.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Declare kokoro's required public keys.
+declare -r ssh_public_keys=(
+ "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDg7L/ZaEauETWrPklUTky3kvxqQfe2Ax/2CsSqhNIGNMnK/8d79CHlmY9+dE1FFQ/RzKNCaltgy7XcN/fCYiCZr5jm2ZtnLuGNOTzupMNhaYiPL419qmL+5rZXt4/dWTrsHbFRACxT8j51PcRMO5wgbL0Bg2XXimbx8kDFaurL2gqduQYqlu4lxWCaJqOL71WogcimeL63Nq/yeH5PJPWpqE4P9VUQSwAzBWFK/hLeds/AiP3MgVS65qHBnhq0JsHy8JQsqjZbG7Iidt/Ll0+gqzEbi62gDIcczG4KC0iOVzDDP/1BxDtt1lKeA23ll769Fcm3rJyoBMYxjvdw1TDx sabujp@trigger.mtv.corp.google.com"
+ "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNgGK/hCdjmulHfRE3hp4rZs38NCR8yAh0eDsztxqGcuXnuSnL7jOlRrbcQpremJ84omD4eKrIpwJUs+YokMdv4= sabujp@trigger.svl.corp.google.com"
+)
+
+# Install dependencies.
+apt-get update && apt-get install -y rsync coreutils python-psutil qemu-kvm
+
+# We need a kbuilder user.
+if useradd -c "kbuilder user" -m -s /bin/bash kbuilder; then
+ # User was added successfully; we add the relevant SSH keys here.
+ mkdir -p ~kbuilder/.ssh
+ (IFS=$'\n'; echo "${ssh_public_keys[*]}") > ~kbuilder/.ssh/authorized_keys
+ chmod 0600 ~kbuilder/.ssh/authorized_keys
+ chown -R kbuilder ~kbuilder/.ssh
+fi
+
+# Give passwordless sudo access.
+cat > /etc/sudoers.d/kokoro <<EOF
+kbuilder ALL=(ALL) NOPASSWD:ALL
+EOF
+
+# Ensure we can run Docker without sudo.
+usermod -aG docker kbuilder
+
+# Ensure that we can access kvm.
+usermod -aG kvm kbuilder
+
+# Ensure that /tmpfs exists and is writable by kokoro.
+#
+# Note that kokoro will typically attach a second disk (sdb) to the instance
+# that is used for the /tmpfs volume. In the future we could setup an init
+# script that formats and mounts this here; however, we don't expect our build
+# artifacts to be that large.
+mkdir -p /tmpfs && chmod 0777 /tmpfs && touch /tmpfs/READY
diff --git a/kokoro/ubuntu1604/build.sh b/kokoro/ubuntu1604/build.sh
new file mode 100755
index 000000000..d664a3a76
--- /dev/null
+++ b/kokoro/ubuntu1604/build.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Run the image_build.sh script with appropriate parameters.
+IMAGE_PROJECT=ubuntu-os-cloud IMAGE_FAMILY=ubuntu-1604-lts $(dirname $0)/../../tools/image_build.sh $(dirname $0)/??_*.sh
diff --git a/kokoro/ubuntu1804/10_core.sh b/kokoro/ubuntu1804/10_core.sh
new file mode 120000
index 000000000..6facceeee
--- /dev/null
+++ b/kokoro/ubuntu1804/10_core.sh
@@ -0,0 +1 @@
+../ubuntu1604/10_core.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/20_bazel.sh b/kokoro/ubuntu1804/20_bazel.sh
new file mode 120000
index 000000000..39194c0f5
--- /dev/null
+++ b/kokoro/ubuntu1804/20_bazel.sh
@@ -0,0 +1 @@
+../ubuntu1604/20_bazel.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/25_docker.sh b/kokoro/ubuntu1804/25_docker.sh
new file mode 120000
index 000000000..63269bd83
--- /dev/null
+++ b/kokoro/ubuntu1804/25_docker.sh
@@ -0,0 +1 @@
+../ubuntu1604/25_docker.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/30_containerd.sh b/kokoro/ubuntu1804/30_containerd.sh
new file mode 120000
index 000000000..6ac2377ed
--- /dev/null
+++ b/kokoro/ubuntu1804/30_containerd.sh
@@ -0,0 +1 @@
+../ubuntu1604/30_containerd.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/40_kokoro.sh b/kokoro/ubuntu1804/40_kokoro.sh
new file mode 120000
index 000000000..e861fb5e1
--- /dev/null
+++ b/kokoro/ubuntu1804/40_kokoro.sh
@@ -0,0 +1 @@
+../ubuntu1604/40_kokoro.sh \ No newline at end of file
diff --git a/kokoro/ubuntu1804/build.sh b/kokoro/ubuntu1804/build.sh
new file mode 100755
index 000000000..2b5c9a6f2
--- /dev/null
+++ b/kokoro/ubuntu1804/build.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeo pipefail
+
+# Run the image_build.sh script with appropriate parameters.
+IMAGE_PROJECT=ubuntu-os-cloud IMAGE_FAMILY=ubuntu-1804-lts $(dirname $0)/../../tools/image_build.sh $(dirname $0)/??_*.sh
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index ba233b93f..f45934466 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -2,9 +2,11 @@
# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
# when the host OS may not be Linux.
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "linux",
@@ -44,6 +46,7 @@ go_library(
"sem.go",
"shm.go",
"signal.go",
+ "signalfd.go",
"socket.go",
"splice.go",
"tcp.go",
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
new file mode 100644
index 000000000..85fad9956
--- /dev/null
+++ b/pkg/abi/linux/signalfd.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ // SFD_NONBLOCK is a signalfd(2) flag.
+ SFD_NONBLOCK = 00004000
+
+ // SFD_CLOEXEC is a signalfd(2) flag.
+ SFD_CLOEXEC = 02000000
+)
+
+// SignalfdSiginfo is the siginfo encoding for signalfds.
+type SignalfdSiginfo struct {
+ Signo uint32
+ Errno int32
+ Code int32
+ PID uint32
+ UID uint32
+ FD int32
+ TID uint32
+ Band uint32
+ Overrun uint32
+ TrapNo uint32
+ Status int32
+ Int int32
+ Ptr uint64
+ UTime uint64
+ STime uint64
+ Addr uint64
+ AddrLSB uint16
+ _ [48]uint8
+}
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 39d253b98..6bc486b62 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 47ab65346..5f59866fa 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD
index 09d6c2c1f..543fb54bf 100644
--- a/pkg/binary/BUILD
+++ b/pkg/binary/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
index 0c2dde4f8..51967b811 100644
--- a/pkg/bits/BUILD
+++ b/pkg/bits/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index b692aa3b1..8d31e068c 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "bpf",
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index cdec96df1..a0b21d4bd 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index 830e19e07..32422f9e2 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "cpuid",
diff --git a/pkg/cpuid/cpuid.go b/pkg/cpuid/cpuid.go
index 3fabaf445..5d61dc2ff 100644
--- a/pkg/cpuid/cpuid.go
+++ b/pkg/cpuid/cpuid.go
@@ -418,6 +418,73 @@ var x86FeatureParseOnlyStrings = map[Feature]string{
X86FeaturePREFETCHWT1: "prefetchwt1",
}
+// intelCacheDescriptors describe the caches and TLBs on the system. They are
+// returned in the registers for eax=2. Intel only.
+type intelCacheDescriptor uint8
+
+// Valid cache/TLB descriptors. All descriptors can be found in Intel SDM Vol.
+// 2, Ch. 3.2, "CPUID", Table 3-12 "Encoding of CPUID Leaf 2 Descriptors".
+const (
+ intelNullDescriptor intelCacheDescriptor = 0
+ intelNoTLBDescriptor intelCacheDescriptor = 0xfe
+ intelNoCacheDescriptor intelCacheDescriptor = 0xff
+
+ // Most descriptors omitted for brevity as they are currently unused.
+)
+
+// CacheType describes the type of a cache, as returned in eax[4:0] for eax=4.
+type CacheType uint8
+
+const (
+ // cacheNull indicates that there are no more entries.
+ cacheNull CacheType = iota
+
+ // CacheData is a data cache.
+ CacheData
+
+ // CacheInstruction is an instruction cache.
+ CacheInstruction
+
+ // CacheUnified is a unified instruction and data cache.
+ CacheUnified
+)
+
+// Cache describes the parameters of a single cache on the system.
+//
+// +stateify savable
+type Cache struct {
+ // Level is the hierarchical level of this cache (L1, L2, etc).
+ Level uint32
+
+ // Type is the type of cache.
+ Type CacheType
+
+ // FullyAssociative indicates that entries may be placed in any block.
+ FullyAssociative bool
+
+ // Partitions is the number of physical partitions in the cache.
+ Partitions uint32
+
+ // Ways is the number of ways of associativity in the cache.
+ Ways uint32
+
+ // Sets is the number of sets in the cache.
+ Sets uint32
+
+ // InvalidateHierarchical indicates that WBINVD/INVD from threads
+ // sharing this cache acts upon lower level caches for threads sharing
+ // this cache.
+ InvalidateHierarchical bool
+
+ // Inclusive indicates that this cache is inclusive of lower cache
+ // levels.
+ Inclusive bool
+
+ // DirectMapped indicates that this cache is directly mapped from
+ // address, rather than using a hash function.
+ DirectMapped bool
+}
+
// Just a way to wrap cpuid function numbers.
type cpuidFunction uint32
@@ -494,7 +561,7 @@ func (f Feature) flagString(cpuinfoOnly bool) string {
return ""
}
-// FeatureSet is a set of Features for a cpu.
+// FeatureSet is a set of Features for a CPU.
//
// +stateify savable
type FeatureSet struct {
@@ -521,6 +588,15 @@ type FeatureSet struct {
// SteppingID is part of the processor signature.
SteppingID uint8
+
+ // Caches describes the caches on the CPU.
+ Caches []Cache
+
+ // CacheLine is the size of a cache line in bytes.
+ //
+ // All caches use the same line size. This is not enforced in the CPUID
+ // encoding, but is true on all known x86 processors.
+ CacheLine uint32
}
// FlagsString prints out supported CPU flags. If cpuinfoOnly is true, it is
@@ -557,22 +633,27 @@ func (fs FeatureSet) CPUInfo(cpu uint) string {
fmt.Fprintln(&b, "wp\t\t: yes")
fmt.Fprintf(&b, "flags\t\t: %s\n", fs.FlagsString(true))
fmt.Fprintf(&b, "bogomips\t: %.02f\n", cpuFreqMHz) // It's bogus anyway.
- fmt.Fprintf(&b, "clflush size\t: %d\n", 64)
- fmt.Fprintf(&b, "cache_alignment\t: %d\n", 64)
+ fmt.Fprintf(&b, "clflush size\t: %d\n", fs.CacheLine)
+ fmt.Fprintf(&b, "cache_alignment\t: %d\n", fs.CacheLine)
fmt.Fprintf(&b, "address sizes\t: %d bits physical, %d bits virtual\n", 46, 48)
fmt.Fprintln(&b, "power management:") // This is always here, but can be blank.
fmt.Fprintln(&b, "") // The /proc/cpuinfo file ends with an extra newline.
return b.String()
}
+const (
+ amdVendorID = "AuthenticAMD"
+ intelVendorID = "GenuineIntel"
+)
+
// AMD returns true if fs describes an AMD CPU.
func (fs *FeatureSet) AMD() bool {
- return fs.VendorID == "AuthenticAMD"
+ return fs.VendorID == amdVendorID
}
// Intel returns true if fs describes an Intel CPU.
func (fs *FeatureSet) Intel() bool {
- return fs.VendorID == "GenuineIntel"
+ return fs.VendorID == intelVendorID
}
// ErrIncompatible is returned by FeatureSet.HostCompatible if fs is not a
@@ -589,9 +670,18 @@ func (e ErrIncompatible) Error() string {
// CheckHostCompatible returns nil if fs is a subset of the host feature set.
func (fs *FeatureSet) CheckHostCompatible() error {
hfs := HostFeatureSet()
+
if diff := fs.Subtract(hfs); diff != nil {
return ErrIncompatible{fmt.Sprintf("CPU feature set %v incompatible with host feature set %v (missing: %v)", fs.FlagsString(false), hfs.FlagsString(false), diff)}
}
+
+ // The size of a cache line must match, as it is critical to correctly
+ // utilizing CLFLUSH. Other cache properties are allowed to change, as
+ // they are not important to correctness.
+ if fs.CacheLine != hfs.CacheLine {
+ return ErrIncompatible{fmt.Sprintf("CPU cache line size %d incompatible with host cache line size %d", fs.CacheLine, hfs.CacheLine)}
+ }
+
return nil
}
@@ -732,14 +822,6 @@ func (fs *FeatureSet) HasFeature(feature Feature) bool {
return fs.Set[feature]
}
-// IsSubset returns true if the FeatureSet is a subset of the FeatureSet passed in.
-// This is useful if you want to see if a FeatureSet is compatible with another
-// FeatureSet, since you can only run with a given FeatureSet if it's a subset of
-// the host's.
-func (fs *FeatureSet) IsSubset(other *FeatureSet) bool {
- return fs.Subtract(other) == nil
-}
-
// Subtract returns the features present in fs that are not present in other.
// If all features in fs are present in other, Subtract returns nil.
func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
@@ -755,17 +837,6 @@ func (fs *FeatureSet) Subtract(other *FeatureSet) (diff map[Feature]bool) {
return
}
-// TakeFeatureIntersection will set the features in `fs` to the intersection of
-// the features in `fs` and `other` (effectively clearing any feature bits on
-// `fs` that are not also set in `other`).
-func (fs *FeatureSet) TakeFeatureIntersection(other *FeatureSet) {
- for f := range fs.Set {
- if !other.Set[f] {
- delete(fs.Set, f)
- }
- }
-}
-
// EmulateID emulates a cpuid instruction based on the feature set.
func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
switch cpuidFunction(origAx) {
@@ -773,9 +844,8 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
ax = uint32(xSaveInfo) // 0xd (xSaveInfo) is the highest function we support.
bx, dx, cx = fs.vendorIDRegs()
case featureInfo:
- // clflush line size (ebx bits[15:8]) hardcoded as 8. This
- // means cache lines of size 64 bytes.
- bx = 8 << 8
+ // CLFLUSH line size is encoded in quadwords. Other fields in bx unsupported.
+ bx = (fs.CacheLine / 8) << 8
cx = fs.blockMask(block(0))
dx = fs.blockMask(block(1))
ax = fs.signature()
@@ -789,10 +859,46 @@ func (fs *FeatureSet) EmulateID(origAx, origCx uint32) (ax, bx, cx, dx uint32) {
// will always return 01H. Software should ignore this value
// and not interpret it as an informational descriptor." - SDM
//
- // We do not support exposing cache information, but we do set
- // this fixed field because some language runtimes (dlang) get
- // confused by ax = 0 and will loop infinitely.
- ax = 1
+ // We only support reporting cache parameters via
+ // intelDeterministicCacheParams; report as much here.
+ //
+ // We do not support exposing TLB information at all.
+ ax = 1 | (uint32(intelNoCacheDescriptor) << 8)
+ case intelDeterministicCacheParams:
+ if !fs.Intel() {
+ // Reserved on non-Intel.
+ return 0, 0, 0, 0
+ }
+
+ // cx is the index of the cache to describe.
+ if int(origCx) >= len(fs.Caches) {
+ return uint32(cacheNull), 0, 0, 0
+ }
+ c := fs.Caches[origCx]
+
+ ax = uint32(c.Type)
+ ax |= c.Level << 5
+ ax |= 1 << 8 // Always claim the cache is "self-initializing".
+ if c.FullyAssociative {
+ ax |= 1 << 9
+ }
+ // Processor topology not supported.
+
+ bx = fs.CacheLine - 1
+ bx |= (c.Partitions - 1) << 12
+ bx |= (c.Ways - 1) << 22
+
+ cx = c.Sets - 1
+
+ if !c.InvalidateHierarchical {
+ dx |= 1
+ }
+ if c.Inclusive {
+ dx |= 1 << 1
+ }
+ if !c.DirectMapped {
+ dx |= 1 << 2
+ }
case xSaveInfo:
if !fs.UseXsave() {
return 0, 0, 0, 0
@@ -845,10 +951,41 @@ func HostFeatureSet() *FeatureSet {
vendorID := vendorIDFromRegs(bx, cx, dx)
// eax=1 gets basic features in ecx:edx.
- ax, _, cx, dx := HostID(1, 0)
+ ax, bx, cx, dx := HostID(1, 0)
featureBlock0 := cx
featureBlock1 := dx
ef, em, pt, f, m, sid := signatureSplit(ax)
+ cacheLine := 8 * (bx >> 8) & 0xff
+
+ // eax=4, ecx=i gets details about cache index i. Only supported on Intel.
+ var caches []Cache
+ if vendorID == intelVendorID {
+ // ecx selects the cache index until a null type is returned.
+ for i := uint32(0); ; i++ {
+ ax, bx, cx, dx := HostID(4, i)
+ t := CacheType(ax & 0xf)
+ if t == cacheNull {
+ break
+ }
+
+ lineSize := (bx & 0xfff) + 1
+ if lineSize != cacheLine {
+ panic(fmt.Sprintf("Mismatched cache line size: %d vs %d", lineSize, cacheLine))
+ }
+
+ caches = append(caches, Cache{
+ Type: t,
+ Level: (ax >> 5) & 0x7,
+ FullyAssociative: ((ax >> 9) & 1) == 1,
+ Partitions: ((bx >> 12) & 0x3ff) + 1,
+ Ways: ((bx >> 22) & 0x3ff) + 1,
+ Sets: cx + 1,
+ InvalidateHierarchical: (dx & 1) == 0,
+ Inclusive: ((dx >> 1) & 1) == 1,
+ DirectMapped: ((dx >> 2) & 1) == 0,
+ })
+ }
+ }
// eax=7, ecx=0 gets extended features in ecx:ebx.
_, bx, cx, _ = HostID(7, 0)
@@ -883,6 +1020,8 @@ func HostFeatureSet() *FeatureSet {
Family: f,
Model: m,
SteppingID: sid,
+ CacheLine: cacheLine,
+ Caches: caches,
}
}
diff --git a/pkg/cpuid/cpuid_test.go b/pkg/cpuid/cpuid_test.go
index 6ae14d2da..a707ebb55 100644
--- a/pkg/cpuid/cpuid_test.go
+++ b/pkg/cpuid/cpuid_test.go
@@ -57,24 +57,13 @@ var justFPUandPAE = &FeatureSet{
X86FeaturePAE: true,
}}
-func TestIsSubset(t *testing.T) {
- if !justFPU.IsSubset(justFPUandPAE) {
- t.Errorf("Got %v is not subset of %v, want IsSubset being true", justFPU, justFPUandPAE)
+func TestSubtract(t *testing.T) {
+ if diff := justFPU.Subtract(justFPUandPAE); diff != nil {
+ t.Errorf("Got %v is not subset of %v, want diff (%v) to be nil", justFPU, justFPUandPAE, diff)
}
- if justFPUandPAE.IsSubset(justFPU) {
- t.Errorf("Got %v is a subset of %v, want IsSubset being false", justFPU, justFPUandPAE)
- }
-}
-
-func TestTakeFeatureIntersection(t *testing.T) {
- testFeatures := HostFeatureSet()
- testFeatures.TakeFeatureIntersection(justFPU)
- if !testFeatures.IsSubset(justFPU) {
- t.Errorf("Got more features than expected after intersecting host features with justFPU: %v, want %v", testFeatures.Set, justFPU.Set)
- }
- if !testFeatures.HasFeature(X86FeatureFPU) {
- t.Errorf("Got no features in testFeatures after intersecting, want %v", X86FeatureFPU)
+ if justFPUandPAE.Subtract(justFPU) == nil {
+ t.Errorf("Got %v is a subset of %v, want diff to be nil", justFPU, justFPUandPAE)
}
}
@@ -83,7 +72,7 @@ func TestTakeFeatureIntersection(t *testing.T) {
// if HostFeatureSet gives back junk bits.
func TestHostFeatureSet(t *testing.T) {
hostFeatures := HostFeatureSet()
- if !justFPUandPAE.IsSubset(hostFeatures) {
+ if justFPUandPAE.Subtract(hostFeatures) != nil {
t.Errorf("Got invalid feature set %v from HostFeatureSet()", hostFeatures)
}
}
@@ -175,6 +164,7 @@ func TestEmulateIDBasicFeatures(t *testing.T) {
testFeatures := newEmptyFeatureSet()
testFeatures.Add(X86FeatureCLFSH)
testFeatures.Add(X86FeatureAVX)
+ testFeatures.CacheLine = 64
ax, bx, cx, dx := testFeatures.EmulateID(1, 0)
ECXAVXBit := uint32(1 << uint(X86FeatureAVX))
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 9961baaa9..71f2abc83 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index 785c685a0..c7f549428 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -7,6 +8,9 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/unet",
+ ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 83bcfe220..7691b477b 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,6 +22,8 @@ import (
"runtime"
"sync/atomic"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -185,6 +187,12 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
+// DialUnix connects to a Unix Domain Socket and return the file descriptor.
+func DialUnix(path string) (*FD, error) {
+ socket, err := unet.Connect(path, false)
+ return New(socket.FD()), err
+}
+
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
index e54e7371c..56495cbd9 100644
--- a/pkg/fdchannel/BUILD
+++ b/pkg/fdchannel/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index bd1d614b6..5643d5f26 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -18,6 +19,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
+ "//third_party/gvsync",
],
)
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index 865b6f640..8390915a2 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error {
*ep.dataLen() = w.Len()
// Return control to the client.
+ raceBecomeInactive()
if err := ep.futexSwitchToPeer(); err != nil {
return err
}
@@ -121,7 +122,16 @@ func (ep *Endpoint) enterFutexWait() error {
}
func (ep *Endpoint) exitFutexWait() {
- atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
+ switch eps := atomic.AddInt32(&ep.ctrl.state, -epsBlocked); eps {
+ case 0:
+ return
+ case epsShutdown:
+ // ep.ctrlShutdown() was called while we were blocked, so we are
+ // repsonsible for indicating connection shutdown.
+ ep.shutdownConn()
+ default:
+ panic(fmt.Sprintf("invalid flipcall.Endpoint.ctrl.state after flipcall.Endpoint.exitFutexWait(): %v", eps+epsBlocked))
+ }
}
func (ep *Endpoint) ctrlShutdown() {
@@ -142,5 +152,25 @@ func (ep *Endpoint) ctrlShutdown() {
break
}
}
+ } else {
+ // There is no blocked thread, so we are responsible for indicating
+ // connection shutdown.
+ ep.shutdownConn()
+ }
+}
+
+func (ep *Endpoint) shutdownConn() {
+ switch cs := atomic.SwapUint32(ep.connState(), csShutdown); cs {
+ case ep.activeState:
+ if err := ep.futexWakeConnState(1); err != nil {
+ log.Warningf("failed to FUTEX_WAKE peer Endpoint for shutdown: %v", err)
+ }
+ case ep.inactiveState:
+ // The peer is currently active and will detect shutdown when it tries
+ // to update the connection state.
+ case csShutdown:
+ // The peer also called Endpoint.Shutdown().
+ default:
+ log.Warningf("unexpected connection state before Endpoint.shutdownConn(): %v", cs)
}
}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 5c9212c33..386cee42c 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -42,11 +42,6 @@ type Endpoint struct {
// dataCap is immutable.
dataCap uint32
- // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
- // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
- // accessed using atomic memory operations.
- shutdown uint32
-
// activeState is csClientActive if this is a client Endpoint and
// csServerActive if this is a server Endpoint.
activeState uint32
@@ -55,9 +50,27 @@ type Endpoint struct {
// csClientActive if this is a server Endpoint.
inactiveState uint32
+ // shutdown is non-zero if Endpoint.Shutdown() has been called, or if the
+ // Endpoint has acknowledged shutdown initiated by the peer. shutdown is
+ // accessed using atomic memory operations.
+ shutdown uint32
+
ctrl endpointControlImpl
}
+// EndpointSide indicates which side of a connection an Endpoint belongs to.
+type EndpointSide int
+
+const (
+ // ClientSide indicates that an Endpoint is a client (initially-active;
+ // first method call should be Connect).
+ ClientSide EndpointSide = iota
+
+ // ServerSide indicates that an Endpoint is a server (initially-inactive;
+ // first method call should be RecvFirst.)
+ ServerSide
+)
+
// Init must be called on zero-value Endpoints before first use. If it
// succeeds, ep.Destroy() must be called once the Endpoint is no longer in use.
//
@@ -65,7 +78,17 @@ type Endpoint struct {
// Endpoint. FD may differ between Endpoints if they are in different
// processes, but must represent the same file. The packet window must
// initially be filled with zero bytes.
-func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+func (ep *Endpoint) Init(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) error {
+ switch side {
+ case ClientSide:
+ ep.activeState = csClientActive
+ ep.inactiveState = csServerActive
+ case ServerSide:
+ ep.activeState = csServerActive
+ ep.inactiveState = csClientActive
+ default:
+ return fmt.Errorf("invalid EndpointSide: %v", side)
+ }
if pwd.Length < pageSize {
return fmt.Errorf("packet window size (%d) less than minimum (%d)", pwd.Length, pageSize)
}
@@ -78,9 +101,6 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
}
ep.packet = m
ep.dataCap = uint32(pwd.Length) - uint32(PacketHeaderBytes)
- // These will be overwritten by ep.Connect() for client Endpoints.
- ep.activeState = csServerActive
- ep.inactiveState = csClientActive
if err := ep.ctrlInit(opts...); err != nil {
ep.unmapPacket()
return err
@@ -90,9 +110,9 @@ func (ep *Endpoint) Init(pwd PacketWindowDescriptor, opts ...EndpointOption) err
// NewEndpoint is a convenience function that returns an initialized Endpoint
// allocated on the heap.
-func NewEndpoint(pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
+func NewEndpoint(side EndpointSide, pwd PacketWindowDescriptor, opts ...EndpointOption) (*Endpoint, error) {
var ep Endpoint
- if err := ep.Init(pwd, opts...); err != nil {
+ if err := ep.Init(side, pwd, opts...); err != nil {
return nil, err
}
return &ep, nil
@@ -115,9 +135,9 @@ func (ep *Endpoint) unmapPacket() {
}
// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
-// ep.RecvFirst(), and ep.SendLast() to unblock and return errors. It does not
-// wait for concurrent calls to return. The effect of Shutdown on the peer
-// Endpoint is unspecified. Successive calls to Shutdown have no effect.
+// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
+// Endpoint, to unblock and return errors. It does not wait for concurrent
+// calls to return. Successive calls to Shutdown have no effect.
//
// Shutdown is the only Endpoint method that may be called concurrently with
// other methods on the same Endpoint.
@@ -152,28 +172,31 @@ const (
// The client is, by definition, initially active, so this must be 0.
csClientActive = 0
csServerActive = 1
+ csShutdown = 2
)
-// Connect designates ep as a client Endpoint and blocks until the peer
-// Endpoint has called Endpoint.RecvFirst().
+// Connect blocks until the peer Endpoint has called Endpoint.RecvFirst().
//
-// Preconditions: ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and
-// ep.SendLast() have never been called.
+// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(),
+// ep.SendRecv(), and ep.SendLast() have never been called.
func (ep *Endpoint) Connect() error {
- ep.activeState = csClientActive
- ep.inactiveState = csServerActive
- return ep.ctrlConnect()
+ err := ep.ctrlConnect()
+ if err == nil {
+ raceBecomeActive()
+ }
+ return err
}
// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
// returns the datagram length specified by that call.
//
-// Preconditions: ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never
-// been called.
+// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and
+// ep.SendLast() have never been called.
func (ep *Endpoint) RecvFirst() (uint32, error) {
if err := ep.ctrlWaitFirst(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -200,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
// after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then
// they can only shoot themselves in the foot.
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlRoundTrip(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -222,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error {
panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
}
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlWakeLast(); err != nil {
return err
}
diff --git a/pkg/flipcall/flipcall_example_test.go b/pkg/flipcall/flipcall_example_test.go
index edb6a8bef..8d88b845d 100644
--- a/pkg/flipcall/flipcall_example_test.go
+++ b/pkg/flipcall/flipcall_example_test.go
@@ -38,12 +38,12 @@ func Example() {
panic(err)
}
var clientEP Endpoint
- if err := clientEP.Init(pwd); err != nil {
+ if err := clientEP.Init(ClientSide, pwd); err != nil {
panic(err)
}
defer clientEP.Destroy()
var serverEP Endpoint
- if err := serverEP.Init(pwd); err != nil {
+ if err := serverEP.Init(ServerSide, pwd); err != nil {
panic(err)
}
defer serverEP.Destroy()
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index da9d736ab..168a487ec 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -39,11 +39,11 @@ func newTestConnectionWithOptions(tb testing.TB, clientOpts, serverOpts []Endpoi
c.pwa.Destroy()
tb.Fatalf("PacketWindowAllocator.Allocate() failed: %v", err)
}
- if err := c.clientEP.Init(pwd, clientOpts...); err != nil {
+ if err := c.clientEP.Init(ClientSide, pwd, clientOpts...); err != nil {
c.pwa.Destroy()
tb.Fatalf("failed to create client Endpoint: %v", err)
}
- if err := c.serverEP.Init(pwd, serverOpts...); err != nil {
+ if err := c.serverEP.Init(ServerSide, pwd, serverOpts...); err != nil {
c.pwa.Destroy()
c.clientEP.Destroy()
tb.Fatalf("failed to create server Endpoint: %v", err)
@@ -62,17 +62,30 @@ func (c *testConnection) destroy() {
}
func testSendRecv(t *testing.T, c *testConnection) {
+ // This shared variable is used to confirm that synchronization between
+ // flipcall endpoints is visible to the Go race detector.
+ state := 0
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
t.Logf("server Endpoint waiting for packet 1")
if _, err := c.serverEP.RecvFirst(); err != nil {
- t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
+ }
+ state++
+ if state != 2 {
+ t.Errorf("shared state counter: got %d, wanted 2", state)
}
t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
if _, err := c.serverEP.SendRecv(0); err != nil {
- t.Fatalf("server Endpoint.SendRecv() failed: %v", err)
+ t.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
+ }
+ state++
+ if state != 4 {
+ t.Errorf("shared state counter: got %d, wanted 4", state)
}
t.Logf("server Endpoint got packet 3")
}()
@@ -87,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
+ state++
+ if state != 1 {
+ t.Errorf("shared state counter: got %d, wanted 1", state)
+ }
t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
if _, err := c.clientEP.SendRecv(0); err != nil {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
+ state++
+ if state != 3 {
+ t.Errorf("shared state counter: got %d, wanted 3", state)
+ }
t.Logf("client Endpoint got packet 2, sending packet 3")
if err := c.clientEP.SendLast(0); err != nil {
t.Fatalf("client Endpoint.SendLast() failed: %v", err)
@@ -105,7 +126,30 @@ func TestSendRecv(t *testing.T) {
testSendRecv(t, c)
}
-func testShutdownConnect(t *testing.T, c *testConnection) {
+func testShutdownBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ if err := c.clientEP.Connect(); err == nil {
+ t.Errorf("client Endpoint.Connect() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, false)
+}
+
+func TestShutdownBeforeConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeConnect(t, c, true)
+}
+
+func testShutdownDuringConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var clientRun sync.WaitGroup
clientRun.Add(1)
go func() {
@@ -115,44 +159,86 @@ func testShutdownConnect(t *testing.T, c *testConnection) {
}
}()
time.Sleep(time.Second) // to allow c.clientEP.Connect() to block
- c.clientEP.Shutdown()
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
clientRun.Wait()
}
-func TestShutdownConnect(t *testing.T) {
+func TestShutdownDuringConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, false)
+}
+
+func TestShutdownDuringConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringConnect(t, c, true)
+}
+
+func testShutdownBeforeRecvFirst(t *testing.T, c *testConnection, remoteShutdown bool) {
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
+ if _, err := c.serverEP.RecvFirst(); err == nil {
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownBeforeRecvFirstLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownBeforeRecvFirst(t, c, false)
+}
+
+func TestShutdownBeforeRecvFirstRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownConnect(t, c)
+ testShutdownBeforeRecvFirst(t, c, true)
}
-func testShutdownRecvFirstBeforeConnect(t *testing.T, c *testConnection) {
+func testShutdownDuringRecvFirstBeforeConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
- _, err := c.serverEP.RecvFirst()
- if err == nil {
+ if _, err := c.serverEP.RecvFirst(); err == nil {
t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
time.Sleep(time.Second) // to allow c.serverEP.RecvFirst() to block
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownRecvFirstBeforeConnect(t *testing.T) {
+func TestShutdownDuringRecvFirstBeforeConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstBeforeConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstBeforeConnectRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownRecvFirstBeforeConnect(t, c)
+ testShutdownDuringRecvFirstBeforeConnect(t, c, true)
}
-func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
+func testShutdownDuringRecvFirstAfterConnect(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err == nil {
- t.Fatalf("server Endpoint.RecvFirst() succeeded unexpectedly")
+ t.Errorf("server Endpoint.RecvFirst() succeeded unexpectedly")
}
}()
defer func() {
@@ -164,23 +250,75 @@ func testShutdownRecvFirstAfterConnect(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownRecvFirstAfterConnect(t *testing.T) {
+func TestShutdownDuringRecvFirstAfterConnectLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, false)
+}
+
+func TestShutdownDuringRecvFirstAfterConnectRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringRecvFirstAfterConnect(t, c, true)
+}
+
+func testShutdownDuringClientSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
+ var serverRun sync.WaitGroup
+ serverRun.Add(1)
+ go func() {
+ defer serverRun.Done()
+ if _, err := c.serverEP.RecvFirst(); err != nil {
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ }
+ // At this point, the client must be blocked in c.clientEP.SendRecv().
+ if remoteShutdown {
+ c.serverEP.Shutdown()
+ } else {
+ c.clientEP.Shutdown()
+ }
+ }()
+ defer func() {
+ // Ensure that the server goroutine is cleaned up before
+ // c.serverEP.Destroy(), even if the test fails.
+ c.serverEP.Shutdown()
+ serverRun.Wait()
+ }()
+ if err := c.clientEP.Connect(); err != nil {
+ t.Fatalf("client Endpoint.Connect() failed: %v", err)
+ }
+ if _, err := c.clientEP.SendRecv(0); err == nil {
+ t.Errorf("client Endpoint.SendRecv() succeeded unexpectedly")
+ }
+}
+
+func TestShutdownDuringClientSendRecvLocal(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringClientSendRecv(t, c, false)
+}
+
+func TestShutdownDuringClientSendRecvRemote(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownRecvFirstAfterConnect(t, c)
+ testShutdownDuringClientSendRecv(t, c, true)
}
-func testShutdownSendRecv(t *testing.T, c *testConnection) {
+func testShutdownDuringServerSendRecv(t *testing.T, c *testConnection, remoteShutdown bool) {
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
defer serverRun.Done()
if _, err := c.serverEP.RecvFirst(); err != nil {
- t.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
}
if _, err := c.serverEP.SendRecv(0); err == nil {
t.Errorf("server Endpoint.SendRecv() succeeded unexpectedly")
@@ -199,14 +337,24 @@ func testShutdownSendRecv(t *testing.T, c *testConnection) {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
time.Sleep(time.Second) // to allow serverEP.SendRecv() to block
- c.serverEP.Shutdown()
+ if remoteShutdown {
+ c.clientEP.Shutdown()
+ } else {
+ c.serverEP.Shutdown()
+ }
serverRun.Wait()
}
-func TestShutdownSendRecv(t *testing.T) {
+func TestShutdownDuringServerSendRecvLocal(t *testing.T) {
c := newTestConnection(t)
defer c.destroy()
- testShutdownSendRecv(t, c)
+ testShutdownDuringServerSendRecv(t, c, false)
+}
+
+func TestShutdownDuringServerSendRecvRemote(t *testing.T) {
+ c := newTestConnection(t)
+ defer c.destroy()
+ testShutdownDuringServerSendRecv(t, c, true)
}
func benchmarkSendRecv(b *testing.B, c *testConnection) {
@@ -218,15 +366,17 @@ func benchmarkSendRecv(b *testing.B, c *testConnection) {
return
}
if _, err := c.serverEP.RecvFirst(); err != nil {
- b.Fatalf("server Endpoint.RecvFirst() failed: %v", err)
+ b.Errorf("server Endpoint.RecvFirst() failed: %v", err)
+ return
}
for i := 1; i < b.N; i++ {
if _, err := c.serverEP.SendRecv(0); err != nil {
- b.Fatalf("server Endpoint.SendRecv() failed: %v", err)
+ b.Errorf("server Endpoint.SendRecv() failed: %v", err)
+ return
}
}
if err := c.serverEP.SendLast(0); err != nil {
- b.Fatalf("server Endpoint.SendLast() failed: %v", err)
+ b.Errorf("server Endpoint.SendLast() failed: %v", err)
}
}()
defer func() {
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index 7c8977893..a37952637 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -17,17 +17,19 @@ package flipcall
import (
"reflect"
"unsafe"
+
+ "gvisor.dev/gvisor/third_party/gvsync"
)
-// Packets consist of an 8-byte header followed by an arbitrarily-sized
+// Packets consist of a 16-byte header followed by an arbitrarily-sized
// datagram. The header consists of:
//
// - A 4-byte native-endian connection state.
//
// - A 4-byte native-endian datagram length in bytes.
+//
+// - 8 reserved bytes.
const (
- sizeofUint32 = unsafe.Sizeof(uint32(0))
-
// PacketHeaderBytes is the size of a flipcall packet header in bytes. The
// maximum datagram size supported by a flipcall connection is equal to the
// length of the packet window minus PacketHeaderBytes.
@@ -35,7 +37,7 @@ const (
// PacketHeaderBytes is exported to support its use in constant
// expressions. Non-constant expressions may prefer to use
// PacketWindowLengthForDataCap().
- PacketHeaderBytes = 2 * sizeofUint32
+ PacketHeaderBytes = 16
)
func (ep *Endpoint) connState() *uint32 {
@@ -43,7 +45,7 @@ func (ep *Endpoint) connState() *uint32 {
}
func (ep *Endpoint) dataLen() *uint32 {
- return (*uint32)((unsafe.Pointer)(ep.packet + sizeofUint32))
+ return (*uint32)((unsafe.Pointer)(ep.packet + 4))
}
// Data returns the datagram part of ep's packet window as a byte slice.
@@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte {
bsReflect.Cap = int(ep.dataCap)
return bs
}
+
+// ioSync is a dummy variable used to indicate synchronization to the Go race
+// detector. Compare syscall.ioSync.
+var ioSync int64
+
+func raceBecomeActive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceAcquire((unsafe.Pointer)(&ioSync))
+ }
+}
+
+func raceBecomeInactive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ }
+}
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
index e7dd812b3..b127a2bbb 100644
--- a/pkg/flipcall/futex_linux.go
+++ b/pkg/flipcall/futex_linux.go
@@ -59,7 +59,12 @@ func (ep *Endpoint) futexConnect(req *ctrlHandshakeRequest) (ctrlHandshakeRespon
func (ep *Endpoint) futexSwitchToPeer() error {
// Update connection state to indicate that the peer should be active.
if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
- return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", atomic.LoadUint32(ep.connState()))
+ switch cs := atomic.LoadUint32(ep.connState()); cs {
+ case csShutdown:
+ return shutdownError{}
+ default:
+ return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
+ }
}
// Wake the peer's Endpoint.futexSwitchFromPeer().
@@ -75,16 +80,18 @@ func (ep *Endpoint) futexSwitchFromPeer() error {
case ep.activeState:
return nil
case ep.inactiveState:
- // Continue to FUTEX_WAIT.
+ if ep.isShutdownLocally() {
+ return shutdownError{}
+ }
+ if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
+ return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
+ }
+ continue
+ case csShutdown:
+ return shutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
- if ep.isShutdownLocally() {
- return shutdownError{}
- }
- if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
- return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
- }
}
}
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index 11716af81..0c5f50397 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
index e6a8dbd02..4b9321711 100644
--- a/pkg/gate/BUILD
+++ b/pkg/gate/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD
index 8f3defa25..34d2673ef 100644
--- a/pkg/ilist/BUILD
+++ b/pkg/ilist/BUILD
@@ -1,5 +1,6 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index c8e923a74..a5d980d14 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index 12615240c..fc5f5779b 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index 3b8a691f4..dd6ca6d39 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -1,5 +1,7 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -21,6 +23,12 @@ proto_library(
visibility = ["//:sandbox"],
)
+cc_proto_library(
+ name = "metric_cc_proto",
+ visibility = ["//:sandbox"],
+ deps = [":metric_proto"],
+)
+
go_proto_library(
name = "metric_go_proto",
importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto",
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index c6737bf97..f32244c69 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
@@ -19,11 +20,14 @@ go_library(
"pool.go",
"server.go",
"transport.go",
+ "transport_flipcall.go",
"version.go",
],
importpath = "gvisor.dev/gvisor/pkg/p9",
deps = [
"//pkg/fd",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 7dc20aeef..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -20,6 +20,8 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -77,6 +79,47 @@ type Client struct {
// fidPool is the collection of available fids.
fidPool pool
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
// pending is the set of pending messages.
pending map[Tag]*response
pendingMu sync.Mutex
@@ -89,25 +132,12 @@ type Client struct {
// Whoever writes to this channel is permitted to call recv. When
// finished calling recv, this channel should be emptied.
recvr chan bool
-
- // messageSize is the maximum total size of a message.
- messageSize uint32
-
- // payloadSize is the maximum payload size of a read or write
- // request. For large reads and writes this means that the
- // read or write is broken up into buffer-size/payloadSize
- // requests.
- payloadSize uint32
-
- // version is the agreed upon version X of 9P2000.L.Google.X.
- // version 0 implies 9P2000.L.
- version uint32
}
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
return nil, ErrBadVersionString
}
for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion)
+ err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
// The server told us to try again with a lower version.
if err == syscall.EAGAIN {
@@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.version = version
break
}
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fallback.
+ c.sendRecv = c.sendRecvLegacy
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacy
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ unix.PollFd{
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
// handleOne handles a single incoming message.
//
// This should only be called with the token from recvr. Note that the received
@@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
-// sendRecv performs a roundtrip message exchange.
+// sendRecvLegacy performs a roundtrip message exchange.
//
// This is called by internal functions.
-func (c *Client) sendRecv(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) error {
tag, ok := c.tagPool.Get()
if !ok {
return ErrOutOfTags
@@ -296,12 +479,62 @@ func (c *Client) sendRecv(t message, r message) error {
return nil
}
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacy(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ return err
+ }
+ }
+
+ // Send the message.
+ err := ch.sendRecv(c, t, r)
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return err
+}
+
// Version returns the negotiated 9P2000.L.Google version number.
func (c *Client) Version() uint32 {
return c.version
}
-// Close closes the underlying socket.
-func (c *Client) Close() error {
- return c.socket.Close()
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 87b2dd61e..29a0afadf 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -35,23 +35,23 @@ func TestVersion(t *testing.T) {
go s.Handle(serverSocket)
// NewClient does a Tversion exchange, so this is our test for success.
- c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString())
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
if err != nil {
t.Fatalf("got %v, expected nil", err)
}
// Check a bogus version string.
- if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a bogus version number.
- if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a too high version number.
- if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN {
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN {
t.Errorf("got %v expected %v", err, syscall.EAGAIN)
}
@@ -60,3 +60,45 @@ func TestVersion(t *testing.T) {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
}
+
+func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ // See above.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // See above.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // See above.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ b.Fatalf("got %v, expected nil", err)
+ }
+
+ // Initialize messages.
+ sendRecv := fn(c)
+ tversion := &Tversion{
+ Version: versionString(highestSupportedVersion),
+ MSize: DefaultMessageSize,
+ }
+ rversion := new(Rversion)
+
+ // Run in a loop.
+ for i := 0; i < b.N; i++ {
+ if err := sendRecv(tversion, rversion); err != nil {
+ b.Fatalf("got unexpected err: %v", err)
+ }
+ }
+}
+
+func BenchmarkSendRecvLegacy(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+}
+
+func BenchmarkSendRecvChannel(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel })
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 999b4f684..ba9a55d6d 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message {
ref.opened = true
ref.openFlags = t.Flags
- return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}
+ rlopen := &Rlopen{QID: qid, IoUnit: ioUnit}
+ rlopen.SetFilePayload(osFile)
+ return rlopen
}
func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
@@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
// Replace the FID reference.
cs.InsertFID(t.FID, newRef)
- return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil
+ rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}}
+ rlcreate.SetFilePayload(osFile)
+ return rlcreate, nil
}
// handle implements handler.handle.
@@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message {
return newErr(err)
}
- return &Rlconnect{File: osFile}
+ rlconnect := &Rlconnect{}
+ rlconnect.SetFilePayload(osFile)
+ return rlconnect
+}
+
+// handle implements handler.handle.
+func (t *Tchannel) handle(cs *connState) message {
+ // Ensure that channels are enabled.
+ if err := cs.initializeChannels(); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the given channel.
+ ch := cs.lookupChannel(t.ID)
+ if ch == nil {
+ return newErr(syscall.ENOSYS)
+ }
+
+ // Return the payload. Note that we need to duplicate the file
+ // descriptor for the channel allocator, because sending is a
+ // destructive operation between sendRecvLegacy (and now the newer
+ // channel send operations). Same goes for the client FD.
+ rchannel := &Rchannel{
+ Offset: uint64(ch.desc.Offset),
+ Length: uint64(ch.desc.Length),
+ }
+ switch t.Control {
+ case 0:
+ // Open the main data channel.
+ mfd, err := syscall.Dup(int(cs.channelAlloc.FD()))
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(mfd))
+ case 1:
+ cfd, err := syscall.Dup(ch.client.FD())
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(cfd))
+ default:
+ return newErr(syscall.EINVAL)
+ }
+ return rchannel
}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index fd9eb1c5d..ffdd7e8c6 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -64,6 +64,21 @@ type filer interface {
SetFilePayload(*fd.FD)
}
+// filePayload embeds a File object.
+type filePayload struct {
+ File *fd.FD
+}
+
+// FilePayload returns the file payload.
+func (f *filePayload) FilePayload() *fd.FD {
+ return f.File
+}
+
+// SetFilePayload sets the received file.
+func (f *filePayload) SetFilePayload(file *fd.FD) {
+ f.File = file
+}
+
// Tversion is a version request.
type Tversion struct {
// MSize is the message size to use.
@@ -524,10 +539,7 @@ type Rlopen struct {
// IoUnit is the recommended I/O unit.
IoUnit uint32
- // File may be attached via the socket.
- //
- // This is an extension specific to this package.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType {
return MsgRlopen
}
-// FilePayload returns the file payload.
-func (r *Rlopen) FilePayload() *fd.FD {
- return r.File
-}
-
-// SetFilePayload sets the received file.
-func (r *Rlopen) SetFilePayload(file *fd.FD) {
- r.File = file
-}
-
// String implements fmt.Stringer.
func (r *Rlopen) String() string {
return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
@@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string {
// Rlconnect is a connect response.
type Rlconnect struct {
- // File is a host socket.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType {
return MsgRlconnect
}
-// FilePayload returns the file payload.
-func (r *Rlconnect) FilePayload() *fd.FD {
- return r.File
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
}
-// SetFilePayload sets the received file.
-func (r *Rlconnect) SetFilePayload(file *fd.FD) {
- r.File = file
+// Tchannel creates a new channel.
+type Tchannel struct {
+ // ID is the channel ID.
+ ID uint32
+
+ // Control is 0 if the Rchannel response should provide the flipcall
+ // component of the channel, and 1 if the Rchannel response should
+ // provide the fdchannel component of the channel.
+ Control uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tchannel) Decode(b *buffer) {
+ t.ID = b.Read32()
+ t.Control = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tchannel) Encode(b *buffer) {
+ b.Write32(t.ID)
+ b.Write32(t.Control)
+}
+
+// Type implements message.Type.
+func (*Tchannel) Type() MsgType {
+ return MsgTchannel
}
// String implements fmt.Stringer.
-func (r *Rlconnect) String() string {
- return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+func (t *Tchannel) String() string {
+ return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control)
+}
+
+// Rchannel is the channel response.
+type Rchannel struct {
+ Offset uint64
+ Length uint64
+ filePayload
+}
+
+// Decode implements encoder.Decode.
+func (r *Rchannel) Decode(b *buffer) {
+ r.Offset = b.Read64()
+ r.Length = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rchannel) Encode(b *buffer) {
+ b.Write64(r.Offset)
+ b.Write64(r.Length)
+}
+
+// Type implements message.Type.
+func (*Rchannel) Type() MsgType {
+ return MsgRchannel
+}
+
+// String implements fmt.Stringer.
+func (r *Rchannel) String() string {
+ return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
}
const maxCacheSize = 3
@@ -2356,4 +2409,6 @@ func init() {
msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
+ msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index e12831dbd..25530adca 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -378,6 +378,8 @@ const (
MsgRlconnect = 137
MsgTallocate = 138
MsgRallocate = 139
+ MsgTchannel = 250
+ MsgRchannel = 251
)
// QIDType represents the file type for QIDs.
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 6e939a49a..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
package(licenses = ["notice"])
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 95846e5f7..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
@@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) {
}()
// Create the client.
- client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString())
+ client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString())
if err != nil {
serverSocket.Close()
clientSocket.Close()
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index b294efbb0..69c886a5d 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -21,6 +21,9 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -45,7 +48,6 @@ type Server struct {
}
// NewServer returns a new server.
-//
func NewServer(attacher Attacher) *Server {
return &Server{
attacher: attacher,
@@ -85,6 +87,8 @@ type connState struct {
// version 0 implies 9P2000.L.
version uint32
+ // -- below relates to the legacy handler --
+
// recvOkay indicates that a receive may start.
recvOkay chan bool
@@ -93,6 +97,20 @@ type connState struct {
// sendDone is signalled when a send is finished.
sendDone chan error
+
+ // -- below relates to the flipcall handler --
+
+ // channelMu protects below.
+ channelMu sync.Mutex
+
+ // channelWg represents active workers.
+ channelWg sync.WaitGroup
+
+ // channelAlloc allocates channel memory.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ // channels are the set of initialized channels.
+ channels []*channel
}
// fidRef wraps a node and tracks references.
@@ -386,6 +404,99 @@ func (cs *connState) WaitTag(t Tag) {
<-ch
}
+// initializeChannels initializes all channels.
+//
+// This is a no-op if channels are already initialized.
+func (cs *connState) initializeChannels() (err error) {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ // Initialize our channel allocator.
+ if cs.channelAlloc == nil {
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return err
+ }
+ cs.channelAlloc = alloc
+ }
+
+ // Create all the channels.
+ for len(cs.channels) < channelsPerClient {
+ res := &channel{
+ done: make(chan struct{}),
+ }
+
+ res.desc, err = cs.channelAlloc.Allocate(channelSize)
+ if err != nil {
+ return err
+ }
+ if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil {
+ return err
+ }
+
+ socks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ res.data.Destroy() // Cleanup.
+ return err
+ }
+ res.fds.Init(socks[0])
+ res.client = fd.New(socks[1])
+
+ cs.channels = append(cs.channels, res)
+
+ // Start servicing the channel.
+ //
+ // When we call stop, we will close all the channels and these
+ // routines should finish. We need the wait group to ensure
+ // that active handlers are actually finished before cleanup.
+ cs.channelWg.Add(1)
+ go func() { // S/R-SAFE: Server side.
+ defer cs.channelWg.Done()
+ res.service(cs)
+ }()
+ }
+
+ return nil
+}
+
+// lookupChannel looks up the channel with given id.
+//
+// The function returns nil if no such channel is available.
+func (cs *connState) lookupChannel(id uint32) *channel {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+ if id >= uint32(len(cs.channels)) {
+ return nil
+ }
+ return cs.channels[id]
+}
+
+// handle handles a single message.
+func (cs *connState) handle(m message) (r message) {
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ recover()
+
+ // Include a useful log message.
+ log.Warningf("panic in handler: %s", debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ return
+}
+
// handleRequest handles a single request.
//
// The recvDone channel is signaled when recv is done (with a error if
@@ -428,41 +539,20 @@ func (cs *connState) handleRequest() {
}
// Handle the message.
- var r message // r is the response.
- defer func() {
- if r == nil {
- // Don't allow a panic to propagate.
- recover()
+ r := cs.handle(m)
- // Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ // Clear the tag before sending. That's because as soon as this hits
+ // the wire, the client can legally send the same tag.
+ cs.ClearTag(tag)
- // Wrap in an EFAULT error; we don't really have a
- // better way to describe this kind of error. It will
- // usually manifest as a result of the test framework.
- r = newErr(syscall.EFAULT)
- }
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
- // Clear the tag before sending. That's because as soon as this
- // hits the wire, the client can legally send another message
- // with the same tag.
- cs.ClearTag(tag)
-
- // Send back the result.
- cs.sendMu.Lock()
- err = send(cs.conn, tag, r)
- cs.sendMu.Unlock()
- cs.sendDone <- err
- }()
- if handler, ok := m.(handler); ok {
- // Call the message handler.
- r = handler.handle(cs)
- } else {
- // Produce an ENOSYS error.
- r = newErr(syscall.ENOSYS)
- }
+ // Return the message to the cache.
msgRegistry.put(m)
- m = nil // 'm' should not be touched after this point.
}
func (cs *connState) handleRequests() {
@@ -477,7 +567,27 @@ func (cs *connState) stop() {
close(cs.recvDone)
close(cs.sendDone)
- for _, fidRef := range cs.fids {
+ // Free the channels.
+ cs.channelMu.Lock()
+ for _, ch := range cs.channels {
+ ch.Shutdown()
+ }
+ cs.channelWg.Wait()
+ for _, ch := range cs.channels {
+ ch.Close()
+ }
+ cs.channels = nil // Clear.
+ cs.channelMu.Unlock()
+
+ // Free the channel memory.
+ if cs.channelAlloc != nil {
+ cs.channelAlloc.Destroy()
+ }
+
+ // Close all remaining fids.
+ for fid, fidRef := range cs.fids {
+ delete(cs.fids, fid)
+
// Drop final reference in the FID table. Note this should
// always close the file, since we've ensured that there are no
// handlers running via the wait for Pending => 0 below.
@@ -510,7 +620,7 @@ func (cs *connState) service() error {
for i := 0; i < pending; i++ {
<-cs.sendDone
}
- return err
+ return nil
}
// This handler is now pending.
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 5648df589..6e8b4bbcd 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -54,7 +54,10 @@ const (
headerLength uint32 = 7
// maximumLength is the largest possible message.
- maximumLength uint32 = 4 * 1024 * 1024
+ maximumLength uint32 = 1 << 20
+
+ // DefaultMessageSize is a sensible default.
+ DefaultMessageSize uint32 = 64 << 10
// initialBufferLength is the initial data buffer we allocate.
initialBufferLength uint32 = 64
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..7cdf4ecc3
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,263 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels to create per client.
+//
+// While the client and server will generally agree on this number, in reality
+// it's completely up to the server. We simply define a minimum of 2, and a
+// maximum of 4, and select the number of available processes as a tie-breaker.
+// Note that we don't want the number of channels to be too large, because each
+// will account for channelSize memory used, which can be large.
+var channelsPerClient = func() int {
+ n := runtime.NumCPU()
+ if n < 2 {
+ return 2
+ }
+ if n > 4 {
+ return 4
+ }
+ return n
+}()
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message size,
+// plus the flipcall packet header, plus the two bytes we write below.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+ desc flipcall.PacketWindowDescriptor
+ data flipcall.Endpoint
+ fds fdchannel.Endpoint
+ buf buffer
+
+ // -- client only --
+ connected bool
+ active bool
+
+ // -- server only --
+ client *fd.FD
+ done chan struct{}
+}
+
+// reset resets the channel buffer.
+func (ch *channel) reset(sz uint32) {
+ ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+func (ch *channel) service(cs *connState) error {
+ rsz, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rsz > 0 {
+ m, err := ch.recv(nil, rsz)
+ if err != nil {
+ return err
+ }
+ r := cs.handle(m)
+ msgRegistry.put(m)
+ rsz, err = ch.send(r)
+ if err != nil {
+ return err
+ }
+ }
+ return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+func (ch *channel) Shutdown() {
+ ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high-level, depending on
+// whether it is the client or server. This cannot be called while there are
+// active callers in either service or sendRecv.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+ // Close all backing transports.
+ ch.fds.Destroy()
+ ch.data.Destroy()
+ if ch.client != nil {
+ ch.client.Close()
+ }
+ return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Not that in the
+// server case, this is the size of the next request.
+func (ch *channel) send(m message) (uint32, error) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [channel @%p] %s", ch, m.String())
+ }
+
+ // Send any file payload.
+ sentFD := false
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ if err := ch.fds.SendFD(f.FD()); err != nil {
+ return 0, syscall.EIO // Map everything to EIO.
+ }
+ f.Close() // Per sendRecvLegacy.
+ sentFD = true // To mark below.
+ }
+ }
+
+ // Encode the message.
+ //
+ // Note that IPC itself encodes the length of messages, so we don't
+ // need to encode a standard 9P header. We write only the message type.
+ ch.reset(0)
+
+ ch.buf.WriteMsgType(m.Type())
+ if sentFD {
+ ch.buf.Write8(1) // Incoming FD.
+ } else {
+ ch.buf.Write8(0) // No incoming FD.
+ }
+ m.Encode(&ch.buf)
+ ssz := uint32(len(ch.buf.data)) // Updated below.
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ copy(ch.data.Data()[ssz:], p)
+ ssz += uint32(len(p))
+ }
+
+ // Perform the one-shot communication.
+ n, err := ch.data.SendRecv(ssz)
+ if err != nil {
+ if n > 0 {
+ return n, nil
+ }
+ return 0, syscall.EIO // See above.
+ }
+
+ return n, nil
+}
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+ // Decode the response from the inline buffer.
+ ch.reset(rsz)
+ t := ch.buf.ReadMsgType()
+ hasFD := ch.buf.Read8() != 0
+ if t == MsgRlerror {
+ // Change the message type. We check for this special case
+ // after decoding below, and transform into an error.
+ r = &Rlerror{}
+ } else if r == nil {
+ nr, err := msgRegistry.get(0, t)
+ if err != nil {
+ return nil, err
+ }
+ r = nr // New message.
+ } else if t != r.Type() {
+ // Not an error and not the expected response; propagate.
+ return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+ }
+
+ // Is there a payload? Copy from the latter portion.
+ if payloader, ok := r.(payloader); ok {
+ fs := payloader.FixedSize()
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
+ ch.buf.data = ch.buf.data[:fs]
+ }
+
+ r.Decode(&ch.buf)
+ if ch.buf.isOverrun() {
+ // Nothing valid was available.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+
+ // Read any FD result.
+ if hasFD {
+ if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+ f := fd.New(rfd)
+ if filer, ok := r.(filer); ok {
+ // Set the payload.
+ filer.SetFilePayload(f)
+ } else {
+ // Don't want the FD.
+ f.Close()
+ }
+ } else {
+ // The header bit was set but nothing came in.
+ log.Warningf("expected FD, got err: %v", err)
+ }
+ }
+
+ // Log a message.
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [channel @%p] %s", ch, r.String())
+ }
+
+ // Convert errors appropriately; see above.
+ if rlerr, ok := r.(*Rlerror); ok {
+ return nil, syscall.Errno(rlerr.Error)
+ }
+
+ return r, nil
+}
+
+// sendRecv sends the given message over the channel.
+//
+// This is used by the client.
+func (ch *channel) sendRecv(c *Client, m, r message) error {
+ rsz, err := ch.send(m)
+ if err != nil {
+ return err
+ }
+ _, err = ch.recv(r, rsz)
+ return err
+}
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index cdb3bc841..2f50ff3ea 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) {
t.Fatalf("unable to create file: %v", err)
}
- if err := send(client, Tag(1), &Rlopen{File: f}); err != nil {
+ rlopen := &Rlopen{}
+ rlopen.SetFilePayload(f)
+ if err := send(client, Tag(1), rlopen); err != nil {
t.Fatalf("send got err %v expected nil", err)
}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index c2a2885ae..f1ffdd23a 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 7
+ highestSupportedVersion uint32 = 8
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool {
func versionSupportsTallocate(v uint32) bool {
return v >= 7
}
+
+// versionSupportsFlipcall returns true if version v supports IPC channels from
+// the flipcall package. Note that these must be negotiated, but this version
+// string indicates that such a facility exists.
+func versionSupportsFlipcall(v uint32) bool {
+ return v >= 8
+}
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
index 697e7a2f4..078f084b2 100644
--- a/pkg/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 9c08452fc..827385139 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "weak_ref_list",
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index 828e9b5c1..ad69e0757 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -215,8 +215,8 @@ type AtomicRefCount struct {
type LeakMode uint32
const (
- // uninitializedLeakChecking indicates that the leak checker has not yet been initialized.
- uninitializedLeakChecking LeakMode = iota
+ // UninitializedLeakChecking indicates that the leak checker has not yet been initialized.
+ UninitializedLeakChecking LeakMode = iota
// NoLeakChecking indicates that no effort should be made to check for
// leaks.
@@ -318,7 +318,7 @@ func (r *AtomicRefCount) finalize() {
switch LeakMode(atomic.LoadUint32(&leakMode)) {
case NoLeakChecking:
return
- case uninitializedLeakChecking:
+ case UninitializedLeakChecking:
note = "(Leak checker uninitialized): "
}
if n := r.ReadRefs(); n != 0 {
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index d1024e49d..af94e944d 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
index 0a3d92854..be328db12 100644
--- a/pkg/seccomp/seccomp_unsafe.go
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -35,7 +35,7 @@ type sockFprog struct {
//go:nosplit
func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
// PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 {
return errno
}
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD
index f38fb39f3..22abdc69f 100644
--- a/pkg/secio/BUILD
+++ b/pkg/secio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index 694486296..12d7c77d2 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:private"],
diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD
index 53989301f..2d6379c86 100644
--- a/pkg/sentry/BUILD
+++ b/pkg/sentry/BUILD
@@ -8,5 +8,7 @@ package_group(
packages = [
"//pkg/sentry/...",
"//runsc/...",
+ # Code generated by go_marshal relies on go_marshal libraries.
+ "//tools/go_marshal/...",
],
)
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 7aace2d7b..c71cff9f3 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -42,6 +43,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "registers_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":registers_proto"],
+)
+
go_proto_library(
name = "registers_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto",
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index bf802d1b6..5522cecd0 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 7e8918722..0c86197f7 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "device",
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index d7259b47b..3119a61b6 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fs",
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index fbca06761..3cb73bd78 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
// Remove removes the given file or symlink. The root dirent is used to
// resolve name, and must not be nil.
-func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
+func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error {
// Check the root.
if root == nil {
panic("Dirent.Remove: root must not be nil")
@@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
// Remove cannot remove directories.
if IsDir(child.Inode.StableAttr) {
return syscall.EISDIR
+ } else if dirPath {
+ return syscall.ENOTDIR
}
// Remove cannot remove a mount point.
diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go
index 884e3ff06..47bc72a88 100644
--- a/pkg/sentry/fs/dirent_refs_test.go
+++ b/pkg/sentry/fs/dirent_refs_test.go
@@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) {
}
d := f.Dirent
- if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil {
+ if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil {
t.Fatalf("root.Remove(root, %q) failed: %v", name, err)
}
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index bf00b9c09..b9bd9ed17 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fdpipe",
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index bb8117f89..c0a6e884b 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -515,6 +515,11 @@ type lockedReader struct {
// File is the file to read from.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Read, not ReadAt.
+ Offset int64
}
// Read implements io.Reader.Read.
@@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) {
if r.Ctx.Interrupted() {
return 0, syserror.ErrInterrupted
}
- n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset)
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset)
+ r.Offset += n
return int(n), err
}
@@ -544,11 +550,21 @@ type lockedWriter struct {
// File is the file to write to.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Write, not WriteAt.
+ Offset int64
}
// Write implements io.Writer.Write.
func (w *lockedWriter) Write(buf []byte) (int, error) {
- return w.WriteAt(buf, w.File.offset)
+ if w.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := w.WriteAt(buf, w.Offset)
+ w.Offset += int64(n)
+ return int(n), err
}
// WriteAt implements io.Writer.WriteAt.
@@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
// io.Copy, since our own Write interface does not have this same
// contract. Enforce that here.
for written < len(buf) {
+ if w.Ctx.Interrupted() {
+ return written, syserror.ErrInterrupted
+ }
var n int64
n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written))
if n > 0 {
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index d86f5bf45..b88303f17 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -15,6 +15,8 @@
package fs
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -105,8 +107,11 @@ type FileOperations interface {
// on the destination, following by a buffered copy with standard Read
// and Write operations.
//
+ // If dup is set, the data should be duplicated into the destination
+ // and retained.
+ //
// The same preconditions as Read apply.
- WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error)
+ WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error)
// Write writes src to file at offset and returns the number of bytes
// written which must be greater than or equal to 0. Like Read, file
@@ -126,7 +131,7 @@ type FileOperations interface {
// source. See WriteTo for details regarding how this is called.
//
// The same preconditions as Write apply; FileFlags.Write must be set.
- ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error)
+ ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error)
// Fsync writes buffered modifications of file and/or flushes in-flight
// operations to backing storage based on syncType. The range to sync is
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 9820f0b13..225e40186 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/refs"
@@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme
}
// WriteTo implements FileOperations.WriteTo.
-func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) {
err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
- n, err = ops.WriteTo(ctx, file, dst, opts)
+ n, err = ops.WriteTo(ctx, file, dst, count, dup)
return err // Will overwrite itself.
})
return
@@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm
}
// ReadFrom implements FileOperations.ReadFrom.
-func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) {
// See above; f.upper must be non-nil.
- return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts)
+ return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count)
}
// Fsync implements FileOperations.Fsync.
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 6499f87ac..b4ac83dc4 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "dirty_set_impl",
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index 626b9126a..fc5b3b1a1 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -15,6 +15,8 @@
package fsutil
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu
type FileNoSplice struct{}
// WriteTo implements fs.FileOperations.WriteTo.
-func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index d2495cb83..693625ddc 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -144,7 +144,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
mask := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: newSize}
- if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil {
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
return err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index e70bc28fb..dd80757dc 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -66,10 +66,8 @@ type CachingInodeOperations struct {
// mfp is used to allocate memory that caches backingFile's contents.
mfp pgalloc.MemoryFileProvider
- // forcePageCache indicates the sentry page cache should be used regardless
- // of whether the platform supports host mapped I/O or not. This must not be
- // modified after inode creation.
- forcePageCache bool
+ // opts contains options. opts is immutable.
+ opts CachingInodeOperationsOptions
attrMu sync.Mutex `state:"nosave"`
@@ -116,6 +114,20 @@ type CachingInodeOperations struct {
refs frameRefSet
}
+// CachingInodeOperationsOptions configures a CachingInodeOperations.
+//
+// +stateify savable
+type CachingInodeOperationsOptions struct {
+ // If ForcePageCache is true, use the sentry page cache even if a host file
+ // descriptor is available.
+ ForcePageCache bool
+
+ // If LimitHostFDTranslation is true, apply maxFillRange() constraints to
+ // host file descriptor mappings returned by
+ // CachingInodeOperations.Translate().
+ LimitHostFDTranslation bool
+}
+
// CachedFileObject is a file that may require caching.
type CachedFileObject interface {
// ReadToBlocksAt reads up to dsts.NumBytes() bytes from the file to dsts,
@@ -128,12 +140,16 @@ type CachedFileObject interface {
// WriteFromBlocksAt may return a partial write without an error.
WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)
- // SetMaskedAttributes sets the attributes in attr that are true in mask
- // on the backing file.
+ // SetMaskedAttributes sets the attributes in attr that are true in
+ // mask on the backing file. If the mask contains only ATime or MTime
+ // and the CachedFileObject has an FD to the file, then this operation
+ // is a noop unless forceSetTimestamps is true. This avoids an extra
+ // RPC to the gofer in the open-read/write-close case, when the
+ // timestamps on the file will be updated by the host kernel for us.
//
// SetMaskedAttributes may be called at any point, regardless of whether
// the file was opened.
- SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error
+ SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error
// Allocate allows the caller to reserve disk space for the inode.
// It's equivalent to fallocate(2) with 'mode=0'.
@@ -159,7 +175,7 @@ type CachedFileObject interface {
// NewCachingInodeOperations returns a new CachingInodeOperations backed by
// a CachedFileObject and its initial unstable attributes.
-func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, forcePageCache bool) *CachingInodeOperations {
+func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject, uattr fs.UnstableAttr, opts CachingInodeOperationsOptions) *CachingInodeOperations {
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
if mfp == nil {
panic(fmt.Sprintf("context.Context %T lacks non-nil value for key %T", ctx, pgalloc.CtxMemoryFileProvider))
@@ -167,7 +183,7 @@ func NewCachingInodeOperations(ctx context.Context, backingFile CachedFileObject
return &CachingInodeOperations{
backingFile: backingFile,
mfp: mfp,
- forcePageCache: forcePageCache,
+ opts: opts,
attr: uattr,
hostFileMapper: NewHostFileMapper(),
}
@@ -212,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Perms: true}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil {
return false
}
c.attr.Perms = perms
@@ -234,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode,
UID: owner.UID.Ok(),
GID: owner.GID.Ok(),
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil {
return err
}
if owner.UID.Ok() {
@@ -270,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In
AccessTime: !ts.ATimeOmit,
ModificationTime: !ts.MTimeOmit,
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil {
+ // Call SetMaskedAttributes with forceSetTimestamps = true to make sure
+ // the timestamp is updated.
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil {
return err
}
if !ts.ATimeOmit {
@@ -293,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode,
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: size}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
c.dataMu.Unlock()
return err
}
@@ -382,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode)
c.dirtyAttr.Size = false
// Write out cached attributes.
- if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
c.attrMu.Unlock()
return err
}
@@ -763,7 +781,7 @@ func (rw *inodeReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error
// and memory mappings, and false if c.cache may contain data cached from
// c.backingFile.
func (c *CachingInodeOperations) useHostPageCache() bool {
- return !c.forcePageCache && c.backingFile.FD() >= 0
+ return !c.opts.ForcePageCache && c.backingFile.FD() >= 0
}
// AddMapping implements memmap.Mappable.AddMapping.
@@ -784,11 +802,6 @@ func (c *CachingInodeOperations) AddMapping(ctx context.Context, ms memmap.Mappi
mf.MarkUnevictable(c, pgalloc.EvictableRange{r.Start, r.End})
}
}
- if c.useHostPageCache() && !usage.IncrementalMappedAccounting {
- for _, r := range mapped {
- usage.MemoryAccounting.Inc(r.Length(), usage.Mapped)
- }
- }
c.mapsMu.Unlock()
return nil
}
@@ -802,11 +815,6 @@ func (c *CachingInodeOperations) RemoveMapping(ctx context.Context, ms memmap.Ma
c.hostFileMapper.DecRefOn(r)
}
if c.useHostPageCache() {
- if !usage.IncrementalMappedAccounting {
- for _, r := range unmapped {
- usage.MemoryAccounting.Dec(r.Length(), usage.Mapped)
- }
- }
c.mapsMu.Unlock()
return
}
@@ -835,11 +843,15 @@ func (c *CachingInodeOperations) CopyMapping(ctx context.Context, ms memmap.Mapp
func (c *CachingInodeOperations) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
// Hot path. Avoid defer.
if c.useHostPageCache() {
+ mr := optional
+ if c.opts.LimitHostFDTranslation {
+ mr = maxFillRange(required, optional)
+ }
return []memmap.Translation{
{
- Source: optional,
+ Source: mr,
File: c,
- Offset: optional.Start,
+ Offset: mr.Start,
Perms: usermem.AnyAccess,
},
}, nil
@@ -985,9 +997,7 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
seg, gap = seg.NextNonEmpty()
case gap.Ok() && gap.Start() < fr.End:
newRange := gap.Range().Intersect(fr)
- if usage.IncrementalMappedAccounting {
- usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
- }
+ usage.MemoryAccounting.Inc(newRange.Length(), usage.Mapped)
seg, gap = c.refs.InsertWithoutMerging(gap, newRange, 1).NextNonEmpty()
default:
c.refs.MergeAdjacent(fr)
@@ -1008,9 +1018,7 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
for seg.Ok() && seg.Start() < fr.End {
seg = c.refs.Isolate(seg, fr)
if old := seg.Value(); old == 1 {
- if usage.IncrementalMappedAccounting {
- usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
- }
+ usage.MemoryAccounting.Dec(seg.Range().Length(), usage.Mapped)
seg = c.refs.Remove(seg).NextSegment()
} else {
seg.SetValue(old - 1)
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index dc19255ed..129f314c8 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block
return srcs.NumBytes(), nil
}
-func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -61,7 +61,7 @@ func TestSetPermissions(t *testing.T) {
uattr := fs.WithCurrentTime(ctx, fs.UnstableAttr{
Perms: fs.FilePermsFromMode(0444),
})
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
perms := fs.FilePermsFromMode(0777)
@@ -150,7 +150,7 @@ func TestSetTimestamps(t *testing.T) {
ModificationTime: epoch,
StatusChangeTime: epoch,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.SetTimestamps(ctx, nil, test.ts); err != nil {
@@ -188,7 +188,7 @@ func TestTruncate(t *testing.T) {
uattr := fs.UnstableAttr{
Size: 0,
}
- iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, noopBackingFile{}, uattr, CachingInodeOperationsOptions{})
defer iops.Release()
if err := iops.Truncate(ctx, nil, uattr.Size); err != nil {
@@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B
return w.WriteFromBlocks(srcs)
}
-func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -280,7 +280,7 @@ func TestRead(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
@@ -336,7 +336,7 @@ func TestWrite(t *testing.T) {
uattr := fs.UnstableAttr{
Size: int64(len(buf)),
}
- iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, false /*forcePageCache*/)
+ iops := NewCachingInodeOperations(ctx, newSliceBackingFile(buf), uattr, CachingInodeOperationsOptions{})
defer iops.Release()
// Expect the cache to be initially empty.
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 6b993928c..2b71ca0e1 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "gofer",
diff --git a/pkg/sentry/fs/gofer/fs.go b/pkg/sentry/fs/gofer/fs.go
index 69999dc28..8f8ab5d29 100644
--- a/pkg/sentry/fs/gofer/fs.go
+++ b/pkg/sentry/fs/gofer/fs.go
@@ -54,6 +54,10 @@ const (
// sandbox using files backed by the gofer. If set to false, unix sockets
// cannot be bound to gofer files without an overlay on top.
privateUnixSocketKey = "privateunixsocket"
+
+ // If present, sets CachingInodeOperationsOptions.LimitHostFDTranslation to
+ // true.
+ limitHostFDTranslationKey = "limit_host_fd_translation"
)
// defaultAname is the default attach name.
@@ -134,12 +138,13 @@ func (f *filesystem) Mount(ctx context.Context, device string, flags fs.MountSou
// opts are parsed 9p mount options.
type opts struct {
- fd int
- aname string
- policy cachePolicy
- msize uint32
- version string
- privateunixsocket bool
+ fd int
+ aname string
+ policy cachePolicy
+ msize uint32
+ version string
+ privateunixsocket bool
+ limitHostFDTranslation bool
}
// options parses mount(2) data into structured options.
@@ -237,6 +242,11 @@ func options(data string) (opts, error) {
delete(options, privateUnixSocketKey)
}
+ if _, ok := options[limitHostFDTranslationKey]; ok {
+ o.limitHostFDTranslation = true
+ delete(options, limitHostFDTranslationKey)
+ }
+
// Fail to attach if the caller wanted us to do something that we
// don't support.
if len(options) > 0 {
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 95b064aea..d918d6620 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -215,8 +215,8 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
- if i.skipSetAttr(mask) {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error {
+ if i.skipSetAttr(mask, forceSetTimestamps) {
return nil
}
as, ans := attr.AccessTime.Unix()
@@ -251,13 +251,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa
// when:
// - Mask is empty
// - Mask contains only attributes that cannot be set in the gofer
-// - Mask contains only atime and/or mtime, and host FD exists
+// - forceSetTimestamps is false and mask contains only atime and/or mtime
+// and host FD exists
//
// Updates to atime and mtime can be skipped because cached value will be
// "close enough" to host value, given that operation went directly to host FD.
// Skipping atime updates is particularly important to reduce the number of
// operations sent to the Gofer for readonly files.
-func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
+func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool {
// First remove attributes that cannot be updated.
cpy := mask
cpy.Type = false
@@ -277,6 +278,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
return false
}
+ // If forceSetTimestamps was passed, then we cannot skip.
+ if forceSetTimestamps {
+ return false
+ }
+
+ // Skip if we have a host FD.
i.handlesMu.RLock()
defer i.handlesMu.RUnlock()
return (i.readHandles != nil && i.readHandles.Host != nil) ||
diff --git a/pkg/sentry/fs/gofer/session.go b/pkg/sentry/fs/gofer/session.go
index 69d08a627..50da865c1 100644
--- a/pkg/sentry/fs/gofer/session.go
+++ b/pkg/sentry/fs/gofer/session.go
@@ -117,6 +117,11 @@ type session struct {
// Flags provided to the mount.
superBlockFlags fs.MountSourceFlags `state:"wait"`
+ // limitHostFDTranslation is the value used for
+ // CachingInodeOperationsOptions.LimitHostFDTranslation for all
+ // CachingInodeOperations created by the session.
+ limitHostFDTranslation bool
+
// connID is a unique identifier for the session connection.
connID string `state:"wait"`
@@ -218,8 +223,11 @@ func newInodeOperations(ctx context.Context, s *session, file contextFile, qid p
uattr := unstable(ctx, valid, attr, s.mounter, s.client)
return sattr, &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, s.superBlockFlags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: s.superBlockFlags.ForcePageCache,
+ LimitHostFDTranslation: s.limitHostFDTranslation,
+ }),
}
}
@@ -242,13 +250,14 @@ func Root(ctx context.Context, dev string, filesystem fs.Filesystem, superBlockF
// Construct the session.
s := session{
- connID: dev,
- msize: o.msize,
- version: o.version,
- cachePolicy: o.policy,
- aname: o.aname,
- superBlockFlags: superBlockFlags,
- mounter: mounter,
+ connID: dev,
+ msize: o.msize,
+ version: o.version,
+ cachePolicy: o.policy,
+ aname: o.aname,
+ superBlockFlags: superBlockFlags,
+ limitHostFDTranslation: o.limitHostFDTranslation,
+ mounter: mounter,
}
s.EnableLeakCheck("gofer.session")
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index b1080fb1a..3e532332e 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "host",
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 679d8321a..a6e4a09e3 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error {
if mask.Empty() {
return nil
}
@@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
return unstableAttr(i.mops, &s), nil
}
-// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+// Allocate implements fsutil.CachedFileObject.Allocate.
func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error {
return syscall.Fallocate(i.FD(), 0, offset, length)
}
@@ -200,8 +200,10 @@ func newInode(ctx context.Context, msrc *fs.MountSource, fd int, saveable bool,
// Build the fs.InodeOperations.
uattr := unstableAttr(msrc.MountSourceOperations.(*superOperations), &s)
iops := &inodeOperations{
- fileState: fileState,
- cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, msrc.Flags.ForcePageCache),
+ fileState: fileState,
+ cachingInodeOps: fsutil.NewCachingInodeOperations(ctx, fileState, uattr, fsutil.CachingInodeOperationsOptions{
+ ForcePageCache: msrc.Flags.ForcePageCache,
+ }),
}
// Return the fs.Inode.
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 2526412a4..90331e3b2 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -43,12 +43,15 @@ type TTYFileOperations struct {
// fgProcessGroup is the foreground process group that is currently
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+
+ termios linux.KernelTermios
}
// newTTYFile returns a new fs.File that wraps a TTY FD.
func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
fileOperations: fileOperations{iops: iops},
+ termios: linux.DefaultSlaveTermios,
})
}
@@ -97,9 +100,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme
t.mu.Lock()
defer t.mu.Unlock()
- // Are we allowed to do the write?
- if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
- return 0, err
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
}
return t.fileOperations.Write(ctx, file, src, offset)
}
@@ -144,6 +150,9 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO
return 0, err
}
err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
return 0, err
case linux.TIOCGPGRP:
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index c7f4e2d13..ba3e0233d 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"sync/atomic"
@@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
}
// WriteTo implements FileOperations.WriteTo.
-func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
@@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
}
// ReadFrom implements FileOperations.ReadFrom.
-func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 08d7c0c57..5a7a5b8cd 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "lock_range",
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 9b713e785..ac0398bd9 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -171,8 +171,6 @@ type MountNamespace struct {
// NewMountNamespace returns a new MountNamespace, with the provided node at the
// root, and the given cache size. A root must always be provided.
func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error) {
- creds := auth.CredentialsFromContext(ctx)
-
// Set the root dirent and id on the root mount. The reference returned from
// NewDirent will be donated to the MountNamespace constructed below.
d := NewDirent(ctx, root, "/")
@@ -181,6 +179,7 @@ func NewMountNamespace(ctx context.Context, root *Inode) (*MountNamespace, error
d: newRootMount(1, d),
}
+ creds := auth.CredentialsFromContext(ctx)
mns := MountNamespace{
userns: creds.UserNamespace,
root: d,
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index 70ed854a8..1c93e8886 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "proc",
@@ -31,7 +33,6 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
"//pkg/log",
"//pkg/sentry/context",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 9adb23608..5e28982c5 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -17,10 +17,10 @@ package proc
import (
"bytes"
"fmt"
+ "io"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -28,9 +28,11 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
// newNet creates a new proc net entry.
@@ -57,9 +59,8 @@ func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSo
"ptype": newStaticProcInode(ctx, msrc, []byte("Type Device Function")),
"route": newStaticProcInode(ctx, msrc, []byte("Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT")),
"tcp": seqfile.NewSeqFileInode(ctx, &netTCP{k: k}, msrc),
- "udp": newStaticProcInode(ctx, msrc, []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops")),
-
- "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
+ "udp": seqfile.NewSeqFileInode(ctx, &netUDP{k: k}, msrc),
+ "unix": seqfile.NewSeqFileInode(ctx, &netUnix{k: k}, msrc),
}
if s.SupportsIPv6() {
@@ -216,7 +217,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
for _, se := range n.k.ListSockets() {
s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
continue
}
sfile := s.(*fs.File)
@@ -297,6 +298,42 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
return data, 0
}
+func networkToHost16(n uint16) uint16 {
+ // n is in network byte order, so is big-endian. The most-significant byte
+ // should be stored in the lower address.
+ //
+ // We manually inline binary.BigEndian.Uint16() because Go does not support
+ // non-primitive consts, so binary.BigEndian is a (mutable) var, so calls to
+ // binary.BigEndian.Uint16() require a read of binary.BigEndian and an
+ // interface method call, defeating inlining.
+ buf := [2]byte{byte(n >> 8 & 0xff), byte(n & 0xff)}
+ return usermem.ByteOrder.Uint16(buf[:])
+}
+
+func writeInetAddr(w io.Writer, a linux.SockAddrInet) {
+ // linux.SockAddrInet.Port is stored in the network byte order and is
+ // printed like a number in host byte order. Note that all numbers in host
+ // byte order are printed with the most-significant byte first when
+ // formatted with %X. See get_tcp4_sock() and udp4_format_sock() in Linux.
+ port := networkToHost16(a.Port)
+
+ // linux.SockAddrInet.Addr is stored as a byte slice in big-endian order
+ // (i.e. most-significant byte in index 0). Linux represents this as a
+ // __be32 which is a typedef for an unsigned int, and is printed with
+ // %X. This means that for a little-endian machine, Linux prints the
+ // least-significant byte of the address first. To emulate this, we first
+ // invert the byte order for the address using usermem.ByteOrder.Uint32,
+ // which makes it have the equivalent encoding to a __be32 on a little
+ // endian machine. Note that this operation is a no-op on a big endian
+ // machine. Then similar to Linux, we format it with %X, which will print
+ // the most-significant byte of the __be32 address first, which is now
+ // actually the least-significant byte of the original address in
+ // linux.SockAddrInet.Addr on little endian machines, due to the conversion.
+ addr := usermem.ByteOrder.Uint32(a.Addr[:])
+
+ fmt.Fprintf(w, "%08X:%04X ", addr, port)
+}
+
// netTCP implements seqfile.SeqSource for /proc/net/tcp.
//
// +stateify savable
@@ -311,6 +348,9 @@ func (*netTCP) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
t := kernel.TaskFromContext(ctx)
if h != nil {
@@ -321,7 +361,7 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
for _, se := range n.k.ListSockets() {
s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %+v in socket table, racing with destruction?", se.Sock)
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
continue
}
sfile := s.(*fs.File)
@@ -343,27 +383,23 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
// Field: sl; entry number.
fmt.Fprintf(&buf, "%4d: ", se.ID)
- portBuf := make([]byte, 2)
-
// Field: local_adddress.
var localAddr linux.SockAddrInet
- if local, _, err := sops.GetSockName(t); err == nil {
- localAddr = *local.(*linux.SockAddrInet)
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
}
- binary.LittleEndian.PutUint16(portBuf, localAddr.Port)
- fmt.Fprintf(&buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(localAddr.Addr[:]),
- portBuf)
+ writeInetAddr(&buf, localAddr)
// Field: rem_address.
var remoteAddr linux.SockAddrInet
- if remote, _, err := sops.GetPeerName(t); err == nil {
- remoteAddr = *remote.(*linux.SockAddrInet)
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
}
- binary.LittleEndian.PutUint16(portBuf, remoteAddr.Port)
- fmt.Fprintf(&buf, "%08X:%04X ",
- binary.LittleEndian.Uint32(remoteAddr.Addr[:]),
- portBuf)
+ writeInetAddr(&buf, remoteAddr)
// Field: state; socket state.
fmt.Fprintf(&buf, "%02X ", sops.State())
@@ -386,7 +422,8 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
fmt.Fprintf(&buf, "%5d ", 0)
} else {
- fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(t.UserNamespace()).OrOverflow()))
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
}
// Field: timeout; number of unanswered 0-window probes.
@@ -438,3 +475,125 @@ func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se
}
return data, 0
}
+
+// netUDP implements seqfile.SeqSource for /proc/net/udp.
+//
+// +stateify savable
+type netUDP struct {
+ k *kernel.Kernel
+}
+
+// NeedsUpdate implements seqfile.SeqSource.NeedsUpdate.
+func (*netUDP) NeedsUpdate(generation int64) bool {
+ return true
+}
+
+// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
+func (n *netUDP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
+ // t may be nil here if our caller is not part of a task goroutine. This can
+ // happen for example if we're here for "sentryctl cat". When t is nil,
+ // degrade gracefully and retrieve what we can.
+ t := kernel.TaskFromContext(ctx)
+
+ if h != nil {
+ return nil, 0
+ }
+
+ var buf bytes.Buffer
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
+ if s == nil {
+ log.Debugf("Couldn't resolve weakref with ID %v in socket table, racing with destruction?", se.ID)
+ continue
+ }
+ sfile := s.(*fs.File)
+ sops, ok := sfile.FileOperations.(socket.Socket)
+ if !ok {
+ panic(fmt.Sprintf("Found non-socket file in socket table: %+v", sfile))
+ }
+ if family, stype, _ := sops.Type(); family != linux.AF_INET || stype != linux.SOCK_DGRAM {
+ s.DecRef()
+ // Not udp4 socket.
+ continue
+ }
+
+ // For Linux's implementation, see net/ipv4/udp.c:udp4_format_sock().
+
+ // Field: sl; entry number.
+ fmt.Fprintf(&buf, "%5d: ", se.ID)
+
+ // Field: local_adddress.
+ var localAddr linux.SockAddrInet
+ if t != nil {
+ if local, _, err := sops.GetSockName(t); err == nil {
+ localAddr = *local.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, localAddr)
+
+ // Field: rem_address.
+ var remoteAddr linux.SockAddrInet
+ if t != nil {
+ if remote, _, err := sops.GetPeerName(t); err == nil {
+ remoteAddr = *remote.(*linux.SockAddrInet)
+ }
+ }
+ writeInetAddr(&buf, remoteAddr)
+
+ // Field: state; socket state.
+ fmt.Fprintf(&buf, "%02X ", sops.State())
+
+ // Field: tx_queue, rx_queue; number of packets in the transmit and
+ // receive queue. Unimplemented.
+ fmt.Fprintf(&buf, "%08X:%08X ", 0, 0)
+
+ // Field: tr, tm->when. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%02X:%08X ", 0, 0)
+
+ // Field: retrnsmt. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%08X ", 0)
+
+ // Field: uid.
+ uattr, err := sfile.Dirent.Inode.UnstableAttr(ctx)
+ if err != nil {
+ log.Warningf("Failed to retrieve unstable attr for socket file: %v", err)
+ fmt.Fprintf(&buf, "%5d ", 0)
+ } else {
+ creds := auth.CredentialsFromContext(ctx)
+ fmt.Fprintf(&buf, "%5d ", uint32(uattr.Owner.UID.In(creds.UserNamespace).OrOverflow()))
+ }
+
+ // Field: timeout. Always 0 for UDP.
+ fmt.Fprintf(&buf, "%8d ", 0)
+
+ // Field: inode.
+ fmt.Fprintf(&buf, "%8d ", sfile.InodeID())
+
+ // Field: ref; reference count on the socket inode. Don't count the ref
+ // we obtain while deferencing the weakref to this socket.
+ fmt.Fprintf(&buf, "%d ", sfile.ReadRefs()-1)
+
+ // Field: Socket struct address. Redacted due to the same reason as
+ // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData.
+ fmt.Fprintf(&buf, "%#016p ", (*socket.Socket)(nil))
+
+ // Field: drops; number of dropped packets. Unimplemented.
+ fmt.Fprintf(&buf, "%d", 0)
+
+ fmt.Fprintf(&buf, "\n")
+
+ s.DecRef()
+ }
+
+ data := []seqfile.SeqData{
+ {
+ Buf: []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops \n"),
+ Handle: n,
+ },
+ {
+ Buf: buf.Bytes(),
+ Handle: n,
+ },
+ }
+ return data, 0
+}
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index 20c3eefc8..76433c7d0 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "seqfile",
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index 516efcc4c..d0f351e5a 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "ramfs",
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index eed1c2854..311798811 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -18,7 +18,6 @@ import (
"io"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/secio"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
// Check whether or not the objects being sliced are stream-oriented
- // (i.e. pipes or sockets). If yes, we elide checks and offset locks.
- srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr)
- dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr)
+ // (i.e. pipes or sockets). For all stream-oriented files and files
+ // where a specific offiset is not request, we acquire the file mutex.
+ // This has two important side effects. First, it provides the standard
+ // protection against concurrent writes that would mutate the offset.
+ // Second, it prevents Splice deadlocks. Only internal anonymous files
+ // implement the ReadFrom and WriteTo methods directly, and since such
+ // anonymous files are referred to by a unique fs.File object, we know
+ // that the file mutex takes strict precedence over internal locks.
+ // Since we enforce lock ordering here, we can't deadlock by using
+ // using a file in two different splice operations simultaneously.
+ srcPipe := !IsRegular(src.Dirent.Inode.StableAttr)
+ dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr)
+ dstAppend := !dstPipe && dst.Flags().Append
+ srcLock := srcPipe || !opts.SrcOffset
+ dstLock := dstPipe || !opts.DstOffset || dstAppend
- if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset {
+ switch {
+ case srcLock && dstLock:
switch {
case dst.UniqueID < src.UniqueID:
// Acquire dst first.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
if !src.mu.Lock(ctx) {
+ dst.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
case dst.UniqueID > src.UniqueID:
// Acquire src first.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
if !dst.mu.Lock(ctx) {
+ src.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
case dst.UniqueID == src.UniqueID:
// Acquire only one lock; it's the same file. This is a
// bit of a edge case, but presumably it's possible.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
+ srcLock = false // Only need one unlock.
}
// Use both offsets (locked).
opts.DstStart = dst.offset
opts.SrcStart = src.offset
- } else if !dstPipe && !opts.DstOffset {
+ case dstLock:
// Acquire only dst.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
opts.DstStart = dst.offset // Safe: locked.
- } else if !srcPipe && !opts.SrcOffset {
+ case srcLock:
// Acquire only src.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
opts.SrcStart = src.offset // Safe: locked.
}
- // Check append-only mode and the limit.
- if !dstPipe {
+ var err error
+ if dstAppend {
unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append)
defer unlock()
- if dst.Flags().Append {
- if opts.DstOffset {
- // We need to acquire the lock.
- if !dst.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
- }
- defer dst.mu.Unlock()
- }
- // Figure out the appropriate offset to use.
- if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil {
- return 0, err
- }
- }
+ // Figure out the appropriate offset to use.
+ err = dst.offsetForAppend(ctx, &opts.DstStart)
+ }
+ if err == nil && !dstPipe {
// Enforce file limits.
limit, ok := dst.checkLimit(ctx, opts.DstStart)
switch {
case ok && limit == 0:
- return 0, syserror.ErrExceedsFileSizeLimit
+ err = syserror.ErrExceedsFileSizeLimit
case ok && limit < opts.Length:
opts.Length = limit // Cap the write.
}
}
+ if err != nil {
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+ return 0, err
+ }
- // Attempt to do a WriteTo; this is likely the most efficient.
- //
- // The underlying implementation may be able to donate buffers.
- newOpts := SpliceOpts{
- Length: opts.Length,
- SrcStart: opts.SrcStart,
- SrcOffset: !srcPipe,
- Dup: opts.Dup,
- DstStart: opts.DstStart,
- DstOffset: !dstPipe,
+ // Construct readers and writers for the splice. This is used to
+ // provide a safer locking path for the WriteTo/ReadFrom operations
+ // (since they will otherwise go through public interface methods which
+ // conflict with locking done above), and simplifies the fallback path.
+ w := &lockedWriter{
+ Ctx: ctx,
+ File: dst,
+ Offset: opts.DstStart,
}
- n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts)
- if n == 0 && err != nil {
- // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also
- // be more efficient than a copy if buffers are cached or readily
- // available. (It's unlikely that they can actually be donate
- n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts)
+ r := &lockedReader{
+ Ctx: ctx,
+ File: src,
+ Offset: opts.SrcStart,
}
- if n == 0 && err != nil {
- // If we've failed up to here, and at least one of the sources
- // is a pipe or socket, then we can't properly support dup.
- // Return an error indicating that this operation is not
- // supported.
- if (srcPipe || dstPipe) && newOpts.Dup {
- return 0, syserror.EINVAL
- }
- // We failed to splice the files. But that's fine; we just fall
- // back to a slow path in this case. This copies without doing
- // any mode changes, so should still be more efficient.
- var (
- r io.Reader
- w io.Writer
- )
- fw := &lockedWriter{
- Ctx: ctx,
- File: dst,
- }
- if newOpts.DstOffset {
- // Use the provided offset.
- w = secio.NewOffsetWriter(fw, newOpts.DstStart)
- } else {
- // Writes will proceed with no offset.
- w = fw
- }
- fr := &lockedReader{
- Ctx: ctx,
- File: src,
- }
- if newOpts.SrcOffset {
- // Limit to the given offset and length.
- r = io.NewSectionReader(fr, opts.SrcStart, opts.Length)
- } else {
- // Limit just to the given length.
- r = &io.LimitedReader{fr, opts.Length}
- }
+ // Attempt to do a WriteTo; this is likely the most efficient.
+ n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup {
+ // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be
+ // more efficient than a copy if buffers are cached or readily
+ // available. (It's unlikely that they can actually be donated).
+ n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length)
+ }
- // Copy between the two.
- n, err = io.Copy(w, r)
+ // Support one last fallback option, but only if at least one of
+ // the source and destination are regular files. This is because
+ // if we block at some point, we could lose data. If the source is
+ // not a pipe then reading is not destructive; if the destination
+ // is a regular file, then it is guaranteed not to block writing.
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) {
+ // Fallback to an in-kernel copy.
+ n, err = io.Copy(w, &io.LimitedReader{
+ R: r,
+ N: opts.Length,
+ })
}
// Update offsets, if required.
@@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
}
+ // Drop locks.
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+
return n, err
}
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 59403d9db..f8bf663bb 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -141,9 +141,10 @@ func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, i
}
// Notify implements ktime.TimerListener.Notify.
-func (t *TimerOperations) Notify(exp uint64) {
+func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
atomic.AddUint64(&t.val, exp)
t.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 8f7eb5757..11b680929 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tmpfs",
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 291164986..25811f668 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tty",
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index a41101339..b0c286b7a 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
go_template_instance(
@@ -79,7 +81,7 @@ go_test(
"//pkg/sentry/usermem",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
],
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
index 9fddb4c4c..bfc46dfa6 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/BUILD
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index b51f3e18d..0b471d121 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
}
if !cb.Handle(vfs.Dirent{
- Name: child.diskDirent.FileName(),
- Type: fs.ToDirentType(childType),
- Ino: uint64(child.diskDirent.Inode()),
- Off: fd.off,
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 907d35b7e..2d50e30aa 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "disklayout",
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 49b57a2d6..1aa2bd6a4 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -33,7 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
const (
@@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) {
// Ignore the inode number and offset of dirents because those are likely to
// change as the underlying image changes.
cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
- return p.String() == "Ino" || p.String() == "Off"
+ return p.String() == "Ino" || p.String() == "NextOff"
}, cmp.Ignore())
if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" {
t.Errorf("dirents mismatch (-want +got):\n%s", diff)
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index d2450e810..7e364c5fd 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go
index c52dc781c..c620227c9 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/memfs/directory.go
@@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 0 {
if !cb.Handle(vfs.Dirent{
- Name: ".",
- Type: linux.DT_DIR,
- Ino: vfsd.Impl().(*dentry).inode.ino,
- Off: 0,
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: vfsd.Impl().(*dentry).inode.ino,
+ NextOff: 1,
}) {
return nil
}
@@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
if !cb.Handle(vfs.Dirent{
- Name: "..",
- Type: parentInode.direntType(),
- Ino: parentInode.ino,
- Off: 1,
+ Name: "..",
+ Type: parentInode.direntType(),
+ Ino: parentInode.ino,
+ NextOff: 2,
}) {
return nil
}
@@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Skip other directoryFD iterators.
if child.inode != nil {
if !cb.Handle(vfs.Dirent{
- Name: child.vfsd.Name(),
- Type: child.inode.direntType(),
- Ino: child.inode.ino,
- Off: fd.off,
+ Name: child.vfsd.Name(),
+ Type: child.inode.direntType(),
+ Ino: child.inode.ino,
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 3d8a4deaf..ade6ac946 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD
index f989f2f8b..359468ccc 100644
--- a/pkg/sentry/hostcpu/BUILD
+++ b/pkg/sentry/hostcpu/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -6,6 +7,7 @@ go_library(
name = "hostcpu",
srcs = [
"getcpu_amd64.s",
+ "getcpu_arm64.s",
"hostcpu.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu",
diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s
new file mode 100644
index 000000000..caf9abb89
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_arm64.s
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for
+// the lack of an optimazed way of getting the current CPU number on arm64.
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ MOVW ZR, cpu+0(FP)
+ MOVD $cpu+0(FP), R0
+ MOVD $0x0, R1 // unused
+ MOVD $0x0, R2 // unused
+ MOVD $0xA8, R8 // SYS_GETCPU
+ SVC
+ RET
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 41bee9a22..aba2414d4 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,9 +1,11 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "pending_signals_list",
@@ -83,6 +85,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "uncaught_signal_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":uncaught_signal_proto"],
+)
+
go_proto_library(
name = "uncaught_signal_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto",
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index f46c43128..65427b112 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "epoll_list",
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 1c5f979d4..983ca67ed 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "eventfd",
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 6a31dc044..41f44999c 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "atomicptr_bucket",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8c1f79ab5..3cda03891 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -24,6 +24,7 @@
// TaskSet.mu
// SignalHandlers.mu
// Task.mu
+// runningTasksMu
//
// Locking SignalHandlers.mu in multiple SignalHandlers requires locking
// TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same
@@ -135,6 +136,22 @@ type Kernel struct {
// syslog is the kernel log.
syslog syslog
+ // runningTasksMu synchronizes disable/enable of cpuClockTicker when
+ // the kernel is idle (runningTasks == 0).
+ //
+ // runningTasksMu is used to exclude critical sections when the timer
+ // disables itself and when the first active task enables the timer,
+ // ensuring that tasks always see a valid cpuClock value.
+ runningTasksMu sync.Mutex `state:"nosave"`
+
+ // runningTasks is the total count of tasks currently in
+ // TaskGoroutineRunningSys or TaskGoroutineRunningApp. i.e., they are
+ // not blocked or stopped.
+ //
+ // runningTasks must be accessed atomically. Increments from 0 to 1 are
+ // further protected by runningTasksMu (see incRunningTasks).
+ runningTasks int64
+
// cpuClock is incremented every linux.ClockTick. cpuClock is used to
// measure task CPU usage, since sampling monotonicClock twice on every
// syscall turns out to be unreasonably expensive. This is similar to how
@@ -150,6 +167,22 @@ type Kernel struct {
// cpuClockTicker increments cpuClock.
cpuClockTicker *ktime.Timer `state:"nosave"`
+ // cpuClockTickerDisabled indicates that cpuClockTicker has been
+ // disabled because no tasks are running.
+ //
+ // cpuClockTickerDisabled is protected by runningTasksMu.
+ cpuClockTickerDisabled bool
+
+ // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the
+ // point it was disabled. It is cached here to avoid a lock ordering
+ // violation with cpuClockTicker.mu when runningTaskMu is held.
+ //
+ // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is
+ // true.
+ //
+ // cpuClockTickerSetting is protected by runningTasksMu.
+ cpuClockTickerSetting ktime.Setting
+
// fdMapUids is an ever-increasing counter for generating FDTable uids.
//
// fdMapUids is mutable, and is accessed using atomic memory operations.
@@ -912,6 +945,102 @@ func (k *Kernel) resumeTimeLocked() {
}
}
+func (k *Kernel) incRunningTasks() {
+ for {
+ tasks := atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // Standard case. Simply increment.
+ if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) {
+ continue
+ }
+ return
+ }
+
+ // Transition from 0 -> 1. Synchronize with other transitions and timer.
+ k.runningTasksMu.Lock()
+ tasks = atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // We're no longer the first task, no need to
+ // re-enable.
+ atomic.AddInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ if !k.cpuClockTickerDisabled {
+ // Timer was never disabled.
+ atomic.StoreInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ // We need to update cpuClock for all of the ticks missed while we
+ // slept, and then re-enable the timer.
+ //
+ // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify
+ // always increments cpuClock by 1 regardless of the number of
+ // expirations as a heuristic to avoid over-accounting in cases of CPU
+ // throttling.
+ //
+ // We want to cover the normal case, when all time should be accounted,
+ // so we increment for all expirations. Throttling is less concerning
+ // here because the ticker is only disabled from Notify. This means
+ // that Notify must schedule and compensate for the throttled period
+ // before the timer is disabled. Throttling while the timer is disabled
+ // doesn't matter, as nothing is running or reading cpuClock anyways.
+ //
+ // S/R also adds complication, as there are two cases. Recall that
+ // monotonicClock will jump forward on restore.
+ //
+ // 1. If the ticker is enabled during save, then on Restore Notify is
+ // called with many expirations, covering the time jump, but cpuClock
+ // is only incremented by 1.
+ //
+ // 2. If the ticker is disabled during save, then after Restore the
+ // first wakeup will call this function and cpuClock will be
+ // incremented by the number of expirations across the S/R.
+ //
+ // These cause very different value of cpuClock. But again, since
+ // nothing was running while the ticker was disabled, those differences
+ // don't matter.
+ setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now())
+ if exp > 0 {
+ atomic.AddUint64(&k.cpuClock, exp)
+ }
+
+ // Now that cpuClock is updated it is safe to allow other tasks to
+ // transition to running.
+ atomic.StoreInt64(&k.runningTasks, 1)
+
+ // N.B. we must unlock before calling Swap to maintain lock ordering.
+ //
+ // cpuClockTickerDisabled need not wait until after Swap to become
+ // true. It is sufficient that the timer *will* be enabled.
+ k.cpuClockTickerDisabled = false
+ k.runningTasksMu.Unlock()
+
+ // This won't call Notify (unless it's been ClockTick since setting.At
+ // above). This means we skip the thread group work in Notify. However,
+ // since nothing was running while we were disabled, none of the timers
+ // could have expired.
+ k.cpuClockTicker.Swap(setting)
+
+ return
+ }
+}
+
+func (k *Kernel) decRunningTasks() {
+ tasks := atomic.AddInt64(&k.runningTasks, -1)
+ if tasks < 0 {
+ panic(fmt.Sprintf("Invalid running count %d", tasks))
+ }
+
+ // Nothing to do. The next CPU clock tick will disable the timer if
+ // there is still nothing running. This provides approximately one tick
+ // of slack in which we can switch back and forth between idle and
+ // active without an expensive transition.
+}
+
// WaitExited blocks until all tasks in k have exited.
func (k *Kernel) WaitExited() {
k.tasks.liveGoroutines.Wait()
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index ebcfaa619..d7a7d1169 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -24,6 +25,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "memory_events_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":memory_events_proto"],
+)
+
go_proto_library(
name = "memory_events_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto",
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 4d15cca85..2ce8952e2 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "buffer_list",
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
index 69ef2a720..95bee2d37 100644
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ b/pkg/sentry/kernel/pipe/buffer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/sentry/safemem"
@@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
return n, err
}
+// WriteFromReader writes to the buffer from an io.Reader.
+func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ dst := b.data[b.write:]
+ if count < int64(len(dst)) {
+ dst = b.data[b.write:][:count]
+ }
+ n, err := r.Read(dst)
+ b.write += n
+ return int64(n), err
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
@@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return n, err
}
+// ReadToWriter reads from the buffer into an io.Writer.
+func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
+ src := b.data[b.read:b.write]
+ if count < int64(len(src)) {
+ src = b.data[b.read:][:count]
+ }
+ n, err := w.Write(src)
+ if !dup {
+ b.read += n
+ }
+ return int64(n), err
+}
+
// bufferPool is a pool for buffers.
var bufferPool = sync.Pool{
New: func() interface{} {
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 247e2928e..93b50669f 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F
}
}
+type readOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit limits subsequence reads.
+ limit func(int64)
+
+ // read performs the actual read operation.
+ read func(*buffer) (int64, error)
+}
+
// read reads data from the pipe into dst and returns the number of bytes
// read, or returns ErrWouldBlock if the pipe is empty.
//
// Precondition: this pipe must have readers.
-func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
// Don't block for a zero-length read even if the pipe is empty.
- if dst.NumBytes() == 0 {
+ if ops.left() == 0 {
return 0, nil
}
@@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Limit how much we consume.
- if dst.NumBytes() > p.size {
- dst = dst.TakeFirst64(p.size)
+ if ops.left() > p.size {
+ ops.limit(p.size)
}
done := int64(0)
- for dst.NumBytes() > 0 {
+ for ops.left() > 0 {
// Pop the first buffer.
first := p.data.Front()
if first == nil {
@@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := dst.CopyOutFrom(ctx, first)
+ n, err := ops.read(first)
done += int64(n)
p.size -= n
- dst = dst.DropFirst64(n)
// Empty buffer?
if first.Empty() {
@@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
return done, nil
}
+// dup duplicates all data from this pipe into the given writer.
+//
+// There is no blocking behavior implemented here. The writer may propagate
+// some blocking error. All the writes must be complete writes.
+func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Is the pipe empty?
+ if p.size == 0 {
+ if !p.HasWriters() {
+ // See above.
+ return 0, nil
+ }
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Limit how much we consume.
+ if ops.left() > p.size {
+ ops.limit(p.size)
+ }
+
+ done := int64(0)
+ for buf := p.data.Front(); buf != nil; buf = buf.Next() {
+ n, err := ops.read(buf)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ }
+
+ return done, nil
+}
+
+type writeOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit should limit subsequent writes.
+ limit func(int64)
+
+ // write should write to the provided buffer.
+ write func(*buffer) (int64, error)
+}
+
// write writes data from sv into the pipe and returns the number of bytes
// written. If no bytes are written because the pipe is full (or has less than
// atomicIOBytes free capacity), write returns ErrWouldBlock.
//
// Precondition: this pipe must have writers.
-func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
- wanted := src.NumBytes()
+ wanted := ops.left()
if avail := p.max - p.size; wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
- // Limit to the available capacity.
- src = src.TakeFirst64(avail)
+ ops.limit(avail)
}
done := int64(0)
- for src.NumBytes() > 0 {
+ for ops.left() > 0 {
// Need a new buffer?
last := p.data.Back()
if last == nil || last.Full() {
@@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := src.CopyInTo(ctx, last)
+ n, err := ops.write(last)
done += int64(n)
p.size += n
- src = src.DropFirst64(n)
// Handle errors.
if err != nil {
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index f69dbf27b..7c307f013 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"math"
"syscall"
@@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() {
// Read implements fs.FileOperations.Read.
func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.read(ctx, dst)
+ n, err := rw.Pipe.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, buf)
+ dst = dst.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := buf.ReadToWriter(w, count, dup)
+ count -= n
+ return n, err
+ },
+ }
+ if dup {
+ // There is no notification for dup operations.
+ return rw.Pipe.dup(ctx, ops)
+ }
+ n, err := rw.Pipe.read(ctx, ops)
if n > 0 {
rw.Pipe.Notify(waiter.EventOut)
}
@@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ
// Write implements fs.FileOperations.Write.
func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, src)
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := src.CopyInTo(ctx, buf)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := buf.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
if n > 0 {
rw.Pipe.Notify(waiter.EventIn)
}
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
index c5d095af7..2e861a5a8 100644
--- a/pkg/sentry/kernel/posixtimer.go
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -117,9 +117,9 @@ func (it *IntervalTimer) signalRejectedLocked() {
}
// Notify implements ktime.TimerListener.Notify.
-func (it *IntervalTimer) Notify(exp uint64) {
+func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
if it.target == nil {
- return
+ return ktime.Setting{}, false
}
it.target.tg.pidns.owner.mu.RLock()
@@ -129,7 +129,7 @@ func (it *IntervalTimer) Notify(exp uint64) {
if it.sigpending {
it.overrunCur += exp
- return
+ return ktime.Setting{}, false
}
// sigpending must be set before sendSignalTimerLocked() so that it can be
@@ -148,6 +148,8 @@ func (it *IntervalTimer) Notify(exp uint64) {
if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil {
it.signalRejectedLocked()
}
+
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy. Users of Timer should call
diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD
index 1725b8562..98ea7a0d8 100644
--- a/pkg/sentry/kernel/sched/BUILD
+++ b/pkg/sentry/kernel/sched/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 36edf10f3..80e5e5da3 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "waiter_list",
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index e5f297478..047b5214d 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -328,8 +328,14 @@ func (tg *ThreadGroup) createSession() error {
childTG.processGroup.incRefWithParent(pg)
childTG.processGroup.decRefWithParent(oldParentPG)
})
- tg.processGroup.decRefWithParent(oldParentPG)
+ // If tg.processGroup is an orphan, decRefWithParent will lock
+ // the signal mutex of each thread group in tg.processGroup.
+ // However, tg's signal mutex may already be locked at this
+ // point. We change tg's process group before calling
+ // decRefWithParent to avoid locking tg's signal mutex twice.
+ oldPG := tg.processGroup
tg.processGroup = pg
+ oldPG.decRefWithParent(oldParentPG)
} else {
// The current process group may be nil only in the case of an
// unparented thread group (i.e. the init process). This would
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
new file mode 100644
index 000000000..50b69d154
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -0,0 +1,22 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_library(
+ name = "signalfd",
+ srcs = ["signalfd.go"],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/usermem",
+ "//pkg/syserror",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
new file mode 100644
index 000000000..06fd5ec88
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -0,0 +1,137 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package signalfd provides an implementation of signal file descriptors.
+package signalfd
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalOperations represent a file with signalfd semantics.
+//
+// +stateify savable
+type SignalOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // target is the original task target.
+ //
+ // The semantics here are a bit broken. Linux will always use current
+ // for all reads, regardless of where the signalfd originated. We can't
+ // do exactly that because we need to plumb the context through
+ // EventRegister in order to support proper blocking behavior. This
+ // will undoubtedly become very complicated quickly.
+ target *kernel.Task
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // mask is the signal mask. Protected by mu.
+ mask linux.SignalSet
+}
+
+// New creates a new signalfd object with the supplied mask.
+func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // No task context? Not valid.
+ return nil, syserror.EINVAL
+ }
+ // name matches fs/signalfd.c:signalfd4.
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]")
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{
+ target: t,
+ mask: mask,
+ }), nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *SignalOperations) Release() {}
+
+// Mask returns the signal mask.
+func (s *SignalOperations) Mask() linux.SignalSet {
+ s.mu.Lock()
+ mask := s.mask
+ s.mu.Unlock()
+ return mask
+}
+
+// SetMask sets the signal mask.
+func (s *SignalOperations) SetMask(mask linux.SignalSet) {
+ s.mu.Lock()
+ s.mask = mask
+ s.mu.Unlock()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ // Attempt to dequeue relevant signals.
+ info, err := s.target.Sigtimedwait(s.Mask(), 0)
+ if err != nil {
+ // There must be no signal available.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Copy out the signal info using the specified format.
+ var buf [128]byte
+ binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ Signo: uint32(info.Signo),
+ Errno: info.Errno,
+ Code: info.Code,
+ PID: uint32(info.Pid()),
+ UID: uint32(info.Uid()),
+ Status: info.Status(),
+ Overrun: uint32(info.Overrun()),
+ Addr: info.Addr(),
+ })
+ n, err := dst.CopyOut(ctx, buf[:])
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mask & waiter.EventIn
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) {
+ // Register for the signal set; ignore the passed events.
+ s.target.SignalRegister(entry, waiter.EventMask(s.Mask()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SignalOperations) EventUnregister(entry *waiter.Entry) {
+ // Unregister the original entry.
+ s.target.SignalUnregister(entry)
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index e91f82bb3..c82ef5486 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
"gvisor.dev/gvisor/third_party/gvsync"
)
@@ -133,6 +134,13 @@ type Task struct {
// signalStack is exclusive to the task goroutine.
signalStack arch.SignalStack
+ // signalQueue is a set of registered waiters for signal-related events.
+ //
+ // signalQueue is protected by the signalMutex. Note that the task does
+ // not implement all queue methods, specifically the readiness checks.
+ // The task only broadcast a notification on signal delivery.
+ signalQueue waiter.Queue `state:"zerovalue"`
+
// If groupStopPending is true, the task should participate in a group
// stop in the interrupt path.
//
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index 78ff14b20..ce3e6ef28 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -465,8 +465,8 @@ func (t *Task) SetKeepCaps(k bool) {
// disables the features we don't support anyway, is always set. This
// drastically simplifies this function.
//
-// - We don't implement AT_SECURE, because no_new_privs always being set means
-// that the conditions that require AT_SECURE never arise. (Compare Linux's
+// - We don't set AT_SECURE = 1, because no_new_privs always being set means
+// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's
// security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().)
//
// - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index e76c069b0..8b148db35 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -126,12 +126,22 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) {
t.gosched.Timestamp = now
t.gosched.State = state
t.goschedSeq.EndWrite()
+
+ if state != TaskGoroutineRunningApp {
+ // Task is blocking/stopping.
+ t.k.decRunningTasks()
+ }
}
// Preconditions: The caller must be running on the task goroutine, and leaving
// a state indicated by a previous call to
// t.accountTaskGoroutineEnter(state).
func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) {
+ if state != TaskGoroutineRunningApp {
+ // Task is unblocking/continuing.
+ t.k.incRunningTasks()
+ }
+
now := t.k.CPUClockNow()
if t.gosched.State != state {
panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys))
@@ -330,7 +340,7 @@ func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker {
}
// Notify implements ktime.TimerListener.Notify.
-func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
+func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
// Only increment cpuClock by 1 regardless of the number of expirations.
// This approximately compensates for cases where thread throttling or bad
// Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and
@@ -426,6 +436,27 @@ func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
tgs[i] = nil
}
ticker.tgs = tgs[:0]
+
+ // If nothing is running, we can disable the timer.
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks == 0 {
+ ticker.k.runningTasksMu.Lock()
+ defer ticker.k.runningTasksMu.Unlock()
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks != 0 {
+ // Raced with a 0 -> 1 transition.
+ return setting, false
+ }
+
+ // Stop the timer. We must cache the current setting so the
+ // kernel can access it without violating the lock order.
+ ticker.k.cpuClockTickerSetting = setting
+ ticker.k.cpuClockTickerDisabled = true
+ setting.Enabled = false
+ return setting, true
+ }
+
+ return setting, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 266959a07..39cd1340d 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -28,6 +28,7 @@ import (
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// SignalAction is an internal signal action.
@@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) {
//
// Preconditions: The signal mutex must be locked.
func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
+ // Notify that the signal is queued.
+ t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig)))
+
// - Do not choose tasks that are blocking the signal.
if linux.SignalSetOf(sig)&t.signalMask != 0 {
return false
@@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
t.tg.signalHandlers.mu.Unlock()
return t.deliverSignal(info, act)
}
+
+// SignalRegister registers a waiter for pending signals.
+func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventRegister(e, mask)
+ t.tg.signalHandlers.mu.Unlock()
+}
+
+// SignalUnregister unregisters a waiter for pending signals.
+func (t *Task) SignalUnregister(e *waiter.Entry) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventUnregister(e)
+ t.tg.signalHandlers.mu.Unlock()
+}
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 0eef24bfb..72568d296 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -511,8 +511,9 @@ type itimerRealListener struct {
}
// Notify implements ktime.TimerListener.Notify.
-func (l *itimerRealListener) Notify(exp uint64) {
+func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM))
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index aa6c75d25..107394183 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -280,13 +280,16 @@ func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
// A TimerListener receives expirations from a Timer.
type TimerListener interface {
// Notify is called when its associated Timer expires. exp is the number of
- // expirations.
+ // expirations. setting is the next timer Setting.
//
// Notify is called with the associated Timer's mutex locked, so Notify
// must not take any locks that precede Timer.mu in lock order.
//
+ // If Notify returns true, the timer will use the returned setting
+ // rather than the passed one.
+ //
// Preconditions: exp > 0.
- Notify(exp uint64)
+ Notify(exp uint64, setting Setting) (newSetting Setting, update bool)
// Destroy is called when the timer is destroyed.
Destroy()
@@ -533,7 +536,9 @@ func (t *Timer) Tick() {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
}
@@ -588,7 +593,9 @@ func (t *Timer) Get() (Time, Setting) {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, s
@@ -620,7 +627,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
}
oldS, oldExp := t.setting.At(now)
if oldExp > 0 {
- t.listener.Notify(oldExp)
+ t.listener.Notify(oldExp, oldS)
+ // N.B. The returned Setting doesn't matter because we're about
+ // to overwrite.
}
if f != nil {
f()
@@ -628,7 +637,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
newS, newExp := s.At(now)
t.setting = newS
if newExp > 0 {
- t.listener.Notify(newExp)
+ if newS, ok := t.listener.Notify(newExp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, oldS
@@ -683,11 +694,13 @@ func NewChannelNotifier() (TimerListener, <-chan struct{}) {
}
// Notify implements ktime.TimerListener.Notify.
-func (c *ChannelNotifier) Notify(uint64) {
+func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) {
select {
case c.tchan <- struct{}{}:
default:
}
+
+ return Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy and will close the channel.
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 40025d62d..59649c770 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "limits",
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index bc5b841fb..2d9251e92 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.
return syserror.ENOEXEC
}
+ // N.B. Linux uses vm_brk_flags to map these pages, which only
+ // honors the X bit, always mapping at least RW. ignoring These
+ // pages are not included in the final brk region.
+ prot := usermem.ReadWrite
+ if phdr.Flags&elf.PF_X == elf.PF_X {
+ prot.Execute = true
+ }
+
if _, err := m.MMap(ctx, memmap.MMapOpts{
Length: uint64(anonSize),
Addr: anonAddr,
// Fixed without Unmap will fail the mmap if something is
// already at addr.
- Fixed: true,
- Private: true,
- // N.B. Linux uses vm_brk to map these pages, ignoring
- // the segment protections, instead always mapping RW.
- // These pages are not included in the final brk
- // region.
- Perms: usermem.ReadWrite,
+ Fixed: true,
+ Private: true,
+ Perms: prot,
MaxPerms: usermem.AnyAccess,
}); err != nil {
ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
@@ -464,7 +468,7 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f *fs.File, info el
// base address big enough to fit all segments, so we first create a
// mapping for the total size just to find a region that is big enough.
//
- // It is safe to unmap it immediately with racing with another mapping
+ // It is safe to unmap it immediately without racing with another mapping
// because we are the only one in control of the MemoryManager.
//
// Note that the vaddr of the first PT_LOAD segment is ignored when
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index f6f1ae762..089d1635b 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -308,6 +308,9 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r
arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
+ // The conditions that require AT_SECURE = 1 never arise. See
+ // kernel.Task.updateCredsForExecLocked.
+ arch.AuxEntry{linux.AT_SECURE, 0},
arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC},
arch.AuxEntry{linux.AT_EXECFN, execfn},
arch.AuxEntry{linux.AT_RANDOM, random},
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 29c14ec56..9687e7e76 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "mappable_range",
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 072745a08..b35c8c673 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "file_refcount_set",
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 858f895f2..3fd904c67 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "evictable_range",
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
index eeb634644..b6d008dbe 100644
--- a/pkg/sentry/platform/interrupt/BUILD
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index ad8b95744..31fa48ec5 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -54,6 +55,7 @@ go_test(
],
embed = [":kvm"],
tags = [
+ "manual",
"nogotsan",
"requires-kvm",
],
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 47957bb3b..72c7ec564 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) {
cpu: ^uint32(0),
}, nil
}
+
+// getEventMessage retrieves a message about the ptrace event that just happened.
+func (t *thread) getEventMessage() (uintptr, error) {
+ var msg uintptr
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETEVENTMSG,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(&msg)),
+ 0, 0)
+ if errno != 0 {
+ return msg, errno
+ }
+ return msg, nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 6bf7cd097..9f0ecfbe4 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -267,7 +267,7 @@ func (s *subprocess) newThread() *thread {
// attach attaches to the thread.
func (t *thread) attach() {
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("unable to attach: %v", errno))
}
@@ -355,7 +355,8 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
}
if stopSig == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
- t.dumpAndPanic("wait failed: the process exited")
+ msg, err := t.getEventMessage()
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
}
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
@@ -416,7 +417,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
for {
// Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -426,12 +427,15 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
break
} else {
// Some other signal caused a thread stop; ignore.
+ if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD {
+ log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig)
+ }
continue
}
}
// Complete the actual system call.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -522,17 +526,17 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
for {
// Start running until the next system call.
if isSingleStepping(regs) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU_SINGLESTEP,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index f09b0b3d0..c075b5f91 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -53,7 +53,7 @@ func probeSeccomp() bool {
for {
// Attempt an emulation.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -266,7 +266,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// Enable cpuid-faulting; this may fail on older kernels or hardware,
// so we just disregard the result. Host CPUID will be enabled.
- syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0)
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
// Call the stub; should not return.
stubCall(stubStart, ppid)
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 3b95af617..ea090b686 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD
index 924d8a6d6..6769cd0a5 100644
--- a/pkg/sentry/platform/safecopy/BUILD
+++ b/pkg/sentry/platform/safecopy/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD
index fd6dc8e6e..884020f7b 100644
--- a/pkg/sentry/safemem/BUILD
+++ b/pkg/sentry/safemem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index eace3766d..c303435d5 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -23,7 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// TODO(b/34161764): Move to pkg/abi/linux along with definitions in
+// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
// pkg/sentry/arch.
type sigaction struct {
handler uintptr
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 635042263..5812085fa 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -26,13 +26,16 @@ package epsocket
import (
"bytes"
+ "io"
"math"
+ "reflect"
"sync"
"syscall"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
@@ -52,6 +55,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -205,6 +209,10 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
@@ -224,7 +232,6 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileNoFsync `state:"nosave"`
fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
*waiter.Queue
@@ -409,17 +416,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
return int64(n), nil
}
-// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand
-// based on the requested size.
+// WriteTo implements fs.FileOperations.WriteTo.
+func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+ s.readMu.Lock()
+
+ // Copy as much data as possible.
+ done := int64(0)
+ for count > 0 {
+ // This may return a blocking error.
+ if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
+ return done, err.ToError()
+ }
+
+ // Write to the underlying file.
+ n, err := dst.Write(s.readView)
+ done += int64(n)
+ count -= int64(n)
+ if dup {
+ // That's all we support for dup. This is generally
+ // supported by any Linux system calls, but the
+ // expectation is that now a caller will call read to
+ // actually remove these bytes from the socket.
+ break
+ }
+
+ // Drop that part of the view.
+ s.readView.TrimFront(n)
+ if err != nil {
+ s.readMu.Unlock()
+ return done, err
+ }
+ }
+
+ s.readMu.Unlock()
+ return done, nil
+}
+
+// ioSequencePayload implements tcpip.Payload.
+//
+// t copies user memory bytes on demand based on the requested size.
type ioSequencePayload struct {
ctx context.Context
src usermem.IOSequence
}
-// Get implements tcpip.Payload.
-func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
- if size > i.Size() {
- size = i.Size()
+// FullPayload implements tcpip.Payloader.FullPayload
+func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
+ return i.Payload(int(i.src.NumBytes()))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if max := int(i.src.NumBytes()); size > max {
+ size = max
}
v := buffer.NewView(size)
if _, err := i.src.CopyIn(i.ctx, v); err != nil {
@@ -428,11 +478,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
return v, nil
}
-// Size implements tcpip.Payload.
-func (i *ioSequencePayload) Size() int {
- return int(i.src.NumBytes())
-}
-
// DropFirst drops the first n bytes from underlying src.
func (i *ioSequencePayload) DropFirst(n int) {
i.src = i.src.DropFirst(int(n))
@@ -466,6 +511,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
return int64(n), nil
}
+// readerPayload implements tcpip.Payloader.
+//
+// It allocates a view and reads from a reader on-demand, based on available
+// capacity in the endpoint.
+type readerPayload struct {
+ ctx context.Context
+ r io.Reader
+ count int64
+ err error
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
+ return r.Payload(int(r.count))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if size > int(r.count) {
+ size = int(r.count)
+ }
+ v := buffer.NewView(size)
+ n, err := r.r.Read(v)
+ if n > 0 {
+ // We ignore the error here. It may re-occur on subsequent
+ // reads, but for now we can enqueue some amount of data.
+ r.count -= int64(n)
+ return v[:n], nil
+ }
+ if err == syserror.ErrWouldBlock {
+ return nil, tcpip.ErrWouldBlock
+ } else if err != nil {
+ r.err = err // Save for propation.
+ return nil, tcpip.ErrBadAddress
+ }
+
+ // There is no data and no error. Return an error, which will propagate
+ // r.err, which will be nil. This is the desired result: (0, nil).
+ return nil, tcpip.ErrBadAddress
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ f := &readerPayload{ctx: ctx, r: r, count: count}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
+ Atomic: true, // See above.
+ })
+ }
+ if err == tcpip.ErrWouldBlock {
+ return n, syserror.ErrWouldBlock
+ } else if err != nil {
+ return int64(n), f.err // Propagate error.
+ }
+
+ return int64(n), nil
+}
+
// Readiness returns a mask of ready events for socket s.
func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
r := s.Endpoint.Readiness(mask)
@@ -774,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -790,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -825,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if len(v) == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return append([]byte(v), 0), nil
+
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1162,7 +1292,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1170,7 +1300,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -1188,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2057,7 +2194,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= int64(v.Size()) || dontWait) {
+ if err == nil && (n >= v.src.NumBytes() || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2082,7 +2219,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
@@ -2101,7 +2238,8 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
- if int(args[1].Int()) == syscall.SIOCGSTAMP {
+ switch args[1].Int() {
+ case syscall.SIOCGSTAMP:
s.readMu.Lock()
defer s.readMu.Unlock()
if !s.timestampValid {
@@ -2113,6 +2251,25 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
AddressSpaceActive: true,
})
return 0, err
+
+ case linux.TIOCINQ:
+ v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ // Add bytes removed from the endpoint but not yet sent to the caller.
+ v += len(s.readView)
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to user-space.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
}
return Ioctl(ctx, s.Endpoint, io, args)
@@ -2184,9 +2341,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCOUTQ:
- var v tcpip.SendQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
@@ -2421,7 +2578,8 @@ func (s *SocketOperations) State() uint32 {
return 0
}
- if !s.isPacketBased() {
+ switch {
+ case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -2450,9 +2608,26 @@ func (s *SocketOperations) State() uint32 {
// Internal or unknown state.
return 0
}
+ case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ // UDP socket.
+ switch udp.EndpointState(s.Endpoint.State()) {
+ case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ return linux.TCP_CLOSE
+ case udp.StateConnected:
+ return linux.TCP_ESTABLISHED
+ default:
+ return 0
+ }
+ case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ // TODO(b/112063468): Export states for ICMP sockets.
+ case s.skType == linux.SOCK_RAW:
+ // TODO(b/112063468): Export states for raw sockets.
+ default:
+ // Unknown transport protocol, how did we make this socket?
+ log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem())
+ return 0
}
- // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets.
return 0
}
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 421f93dc4..0a9dfa6c3 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, true, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrNotPermitted
}
switch protocol {
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 9e2e12799..445080aa4 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "port",
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 5061dcbde..3a6baa308 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -49,6 +50,14 @@ proto_library(
],
)
+cc_proto_library(
+ name = "syscall_rpc_cc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [":syscall_rpc_proto"],
+)
+
go_proto_library(
name = "syscall_rpc_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2b0ad6395..1867b3a5c 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,6 +175,10 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
@@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrQueueSizeNotSupported
}
return v, nil
- default:
- return -1, tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
- case *tcpip.SendQueueSizeOption:
+ case tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v := e.connected.SendQueuedSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
- }
- *o = qs
- return nil
-
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- return nil
+ return int(v), nil
- case *tcpip.SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ v := e.connected.SendMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
- return nil
+ return int(v), nil
- case *tcpip.ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
e.Lock()
if e.receiver == nil {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ v := e.receiver.RecvMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
}
- *o = qs
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 445d25010..7d7b42eba 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -44,6 +45,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "strace_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":strace_proto"],
+)
+
go_proto_library(
name = "strace_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go
index 3650fd6e1..5d57b75af 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64.go
@@ -335,4 +335,5 @@ var linuxAMD64 = SyscallMap{
315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 33a40b9c6..e76ee27d2 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -74,6 +74,7 @@ go_library(
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/kernel/shm",
+ "//pkg/sentry/kernel/signalfd",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 2f77d587b..72c383537 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -19,4 +19,4 @@ const (
_LINUX_SYSNAME = "Linux"
_LINUX_RELEASE = "4.4"
_LINUX_VERSION = "#1 SMP Sun Jan 10 15:06:54 PST 2016"
-)
+) \ No newline at end of file
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 2e00a91ce..b9a8e3e21 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -1423,9 +1423,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
if err != nil {
return err
}
- if dirPath {
- return syserror.ENOENT
- }
return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
if !fs.IsDir(d.Inode.StableAttr) {
@@ -1436,7 +1433,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return err
}
- return d.Remove(t, root, name)
+ return d.Remove(t, root, name, dirPath)
})
}
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 3ab54271c..cd31e0649 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -72,6 +72,39 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
}
+// Readahead implements readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL; if the underlying file type does not support readahead,
+ // then Linux will return EINVAL to indicate as much. In the future, we
+ // may extend this function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
+
// Pread64 implements linux syscall pread64(2).
func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 0104a94c0..fb6efd5d8 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -20,7 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?")
return 0, nil, syserror.EINTR
}
+
+// sharedSignalfd is shared between the two calls.
+func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ // Copy in the signal mask.
+ mask, err := copyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Always check for valid flags, even if not creating.
+ if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is this a change to an existing signalfd?
+ //
+ // The spec indicates that this should adjust the mask.
+ if fd != -1 {
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is this a signalfd?
+ if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok {
+ s.SetMask(mask)
+ return 0, nil, nil
+ }
+
+ // Not a signalfd.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create a new file.
+ file, err := signalfd.New(t, mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ // Set appropriate flags.
+ file.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.SFD_NONBLOCK != 0,
+ })
+
+ // Create a new descriptor.
+ fd, err = t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.SFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Done.
+ return uintptr(fd), nil, nil
+}
+
+// Signalfd implements the linux syscall signalfd(2).
+func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, 0)
+}
+
+// Signalfd4 implements the linux syscall signalfd4(2).
+func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ flags := args[3].Int()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, flags)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 3bac4d90d..b5a72ce63 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- if optLen <= 0 {
+ if optLen < 0 {
return 0, nil, syserror.EINVAL
}
if optLen > maxOptLen {
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 17e3dde1f..9f705ebca 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
total int64
n int64
err error
- ch chan struct{}
- inW bool
- outW bool
+ inCh chan struct{}
+ outCh chan struct{}
)
for opts.Length > 0 {
n, err = fs.Splice(t, outFile, inFile, opts)
@@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
break
}
- // Are we a registered waiter?
- if ch == nil {
- ch = make(chan struct{}, 1)
- }
- if !inW && !inFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- inFile.EventRegister(&w, EventMaskRead)
- defer inFile.EventUnregister(&w)
- inW = true // Registered.
- } else if !outW && !outFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- outFile.EventRegister(&w, EventMaskWrite)
- defer outFile.EventUnregister(&w)
- outW = true // Registered.
- }
-
- // Was anything registered? If no, everything is non-blocking.
- if !inW && !outW {
- break
- }
-
- if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) {
- // Something became ready, try again without blocking.
- continue
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the splice operation.
+ if inFile.Readiness(EventMaskRead) == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, EventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(inCh); err != nil {
+ break
+ }
}
-
- // Block until there's data.
- if err = t.Block(ch); err != nil {
- break
+ if outFile.Readiness(EventMaskWrite) == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(outCh); err != nil {
+ break
+ }
}
}
@@ -91,22 +88,29 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Get files.
+ inFile := t.GetFile(inFD)
+ if inFile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer inFile.DecRef()
+
+ if !inFile.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
outFile := t.GetFile(outFD)
if outFile == nil {
return 0, nil, syserror.EBADF
}
defer outFile.DecRef()
- inFile := t.GetFile(inFD)
- if inFile == nil {
+ if !outFile.Flags().Write {
return 0, nil, syserror.EBADF
}
- defer inFile.DecRef()
- // Verify that the outfile Append flag is not set. Note that fs.Splice
- // itself validates that the output file is writable.
+ // Verify that the outfile Append flag is not set.
if outFile.Flags().Append {
- return 0, nil, syserror.EBADF
+ return 0, nil, syserror.EINVAL
}
// Verify that we have a regular infile. This is a requirement; the
@@ -142,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Length: count,
SrcOffset: true,
SrcStart: offset,
- }, false)
+ }, outFile.Flags().NonBlocking)
// Copy out the new offset.
if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
@@ -152,7 +156,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Send data using splice.
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
- }, false)
+ }, outFile.Flags().NonBlocking)
}
// We can only pass a single file to handleIOError, so pick inFile
@@ -174,12 +178,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful. Note that unlike in Linux, this
- // flag is applied consistently. We will have either fully blocking or
- // non-blocking behavior below, regardless of the underlying files
- // being spliced to. It's unclear if this is a bug or not yet.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -193,6 +191,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
defer inFile.DecRef()
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Construct our options.
//
// Note that exactly one of the underlying buffers must be a pipe. We
@@ -240,17 +245,17 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inOffset != 0 || outOffset != 0 {
return 0, nil, syserror.ESPIPE
}
- default:
- return 0, nil, syserror.EINVAL
- }
- // We may not refer to the same pipe; otherwise it's a continuous loop.
- if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ // We may not refer to the same pipe; otherwise it's a continuous loop.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+ default:
return 0, nil, syserror.EINVAL
}
// Splice data.
- n, err := doSplice(t, outFile, inFile, opts, nonBlocking)
+ n, err := doSplice(t, outFile, inFile, opts, nonBlock)
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile)
@@ -268,9 +273,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -294,11 +296,14 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
+ // The operation is non-blocking if anything is non-blocking.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Splice data.
n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
Dup: true,
- }, nonBlocking)
+ }, nonBlock)
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile)
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 4b3f043a2..b887fa9d7 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -15,6 +15,7 @@
package linux
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use
timer.Destroy()
- var remaining time.Duration
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- remaining = 0
- } else {
- remaining = dur - after.Sub(start)
+ switch err {
+ case syserror.ETIMEDOUT:
+ // Slept for entire timeout.
+ return nil
+ case syserror.ErrInterrupted:
+ // Interrupted.
+ remaining := dur - after.Sub(start)
if remaining < 0 {
remaining = time.Duration(0)
}
- }
- // Copy out remaining time.
- if err != nil && rem != usermem.Addr(0) {
- timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
- if err := copyTimespecOut(t, rem, &timeleft); err != nil {
- return err
+ // Copy out remaining time.
+ if rem != 0 {
+ timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
+ if err := copyTimespecOut(t, rem, &timeleft); err != nil {
+ return err
+ }
}
- }
-
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- return nil
- }
- // If interrupted, arrange for a restart with the remaining duration.
- if err == syserror.ErrInterrupted {
+ // Arrange for a restart with the remaining duration.
t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
c: c,
duration: remaining,
rem: rem,
})
return kernel.ERESTART_RESTARTBLOCK
+ default:
+ panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
-
- return err
}
// Nanosleep implements linux syscall Nanosleep(2).
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 271ace08e..748e8dd8d 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- name, err := t.CopyInString(nameAddr, int(size))
- if err != nil {
+ name := make([]byte, size)
+ if _, err := t.CopyInBytes(nameAddr, name); err != nil {
return 0, nil, err
}
- utsns.SetHostName(name)
+ utsns.SetHostName(string(name))
return 0, nil, nil
}
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 8aa6a3017..beb43ba13 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
index b69603da3..fc7614fff 100644
--- a/pkg/sentry/unimpl/BUILD
+++ b/pkg/sentry/unimpl/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -10,6 +11,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "unimplemented_syscall_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":unimplemented_syscall_proto"],
+)
+
go_proto_library(
name = "unimplemented_syscall_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto",
diff --git a/pkg/sentry/usage/memory.go b/pkg/sentry/usage/memory.go
index f4326706a..d6ef644d8 100644
--- a/pkg/sentry/usage/memory.go
+++ b/pkg/sentry/usage/memory.go
@@ -277,8 +277,3 @@ func TotalMemory(memSize, used uint64) uint64 {
}
return memSize
}
-
-// IncrementalMappedAccounting controls whether host mapped memory is accounted
-// incrementally during map translation. This may be modified during early
-// initialization, and is read-only afterward.
-var IncrementalMappedAccounting = false
diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD
index a5b4206bb..cc5d25762 100644
--- a/pkg/sentry/usermem/BUILD
+++ b/pkg/sentry/usermem/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "addr_range",
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
index 6eced660a..7b1f312b1 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/sentry/usermem/usermem.go
@@ -16,6 +16,7 @@
package usermem
import (
+ "bytes"
"errors"
"io"
"strconv"
@@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
// Look for the terminating zero byte, which may have occurred before
// hitting err.
- for i, c := range buf[done : done+n] {
- if c == 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
- }
+ if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
+ return stringFromImmutableBytes(buf[:done+i]), nil
}
+
done += n
if err != nil {
return stringFromImmutableBytes(buf[:done]), err
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 0f247bf77..eff4b44f6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 86bde7fb3..7eb2b2821 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -199,8 +199,11 @@ type Dirent struct {
// Ino is the inode number.
Ino uint64
- // Off is this Dirent's offset.
- Off int64
+ // NextOff is the offset of the *next* Dirent in the directory; that is,
+ // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will
+ // cause the next call to FileDescription.IterDirents() to yield the next
+ // Dirent. (The offset of the first Dirent in a directory is always 0.)
+ NextOff int64
}
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index 00665c939..bdca80d37 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index c0f3c658d..329904457 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index e70f4a79f..8a865d229 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
index b149f9e02..bd3f9fd28 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/syserror/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index df37c7d5a..3fd9e3134 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tcpip",
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 0d2637ee4..78df5a0b1 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 672f026b2..8ced960bb 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) {
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ })
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
return nil, err
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index 3301967fb..b4e8d6810 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "buffer",
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index afcabd51d..096ad71ab 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -586,3 +586,103 @@ func Payload(want []byte) TransportChecker {
}
}
}
+
+// ICMPv4 creates a checker that checks that the transport protocol is ICMPv4 and
+// potentially additional ICMPv4 header fields.
+func ICMPv4(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv4ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv4ProtocolNumber)
+ }
+
+ icmp := header.ICMPv4(last.Payload())
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv4Type creates a checker that checks the ICMPv4 Type field.
+func ICMPv4Type(want header.ICMPv4Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv4Code creates a checker that checks the ICMPv4 Code field.
+func ICMPv4Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv4, ok := h.(header.ICMPv4)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv4", h)
+ }
+ if got := icmpv4.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6 creates a checker that checks that the transport protocol is ICMPv6 and
+// potentially additional ICMPv6 header fields.
+func ICMPv6(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.ICMPv6ProtocolNumber {
+ t.Fatalf("Bad protocol, got %d, want %d", p, header.ICMPv6ProtocolNumber)
+ }
+
+ icmp := header.ICMPv6(last.Payload())
+ for _, f := range checkers {
+ f(t, icmp)
+ }
+ if t.Failed() {
+ t.FailNow()
+ }
+ }
+}
+
+// ICMPv6Type creates a checker that checks the ICMPv6 Type field.
+func ICMPv6Type(want header.ICMPv6Type) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Type(); got != want {
+ t.Fatalf("unexpected icmp type got: %d, want: %d", got, want)
+ }
+ }
+}
+
+// ICMPv6Code creates a checker that checks the ICMPv6 Code field.
+func ICMPv6Code(want byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ icmpv6, ok := h.(header.ICMPv6)
+ if !ok {
+ t.Fatalf("unexpected transport header passed to checker got: %+v, want: header.ICMPv6", h)
+ }
+ if got := icmpv6.Code(); got != want {
+ t.Fatalf("unexpected ICMP code got: %d, want: %d", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index 29b30be9c..0c5c20cea 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 76ef02f13..b558350c3 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "header",
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index c52c0d851..0cac6c0a5 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
@@ -25,13 +26,29 @@ type ICMPv4 []byte
const (
// ICMPv4PayloadOffset defines the start of ICMP payload.
- ICMPv4PayloadOffset = 4
+ ICMPv4PayloadOffset = 8
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 8
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
+
+ // icmpv4ChecksumOffset is the offset of the checksum field
+ // in an ICMPv4 message.
+ icmpv4ChecksumOffset = 2
+
+ // icmpv4MTUOffset is the offset of the MTU field
+ // in a ICMPv4FragmentationNeeded message.
+ icmpv4MTUOffset = 6
+
+ // icmpv4IdentOffset is the offset of the ident field
+ // in a ICMPv4EchoRequest/Reply message.
+ icmpv4IdentOffset = 4
+
+ // icmpv4SequenceOffset is the offset of the sequence field
+ // in a ICMPv4EchoRequest/Reply message.
+ icmpv4SequenceOffset = 6
)
// ICMPv4Type is the ICMP type field described in RFC 792.
@@ -72,12 +89,12 @@ func (b ICMPv4) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
- return binary.BigEndian.Uint16(b[2:])
+ return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
}
// SetChecksum sets the ICMP checksum field.
func (b ICMPv4) SetChecksum(checksum uint16) {
- binary.BigEndian.PutUint16(b[2:], checksum)
+ binary.BigEndian.PutUint16(b[icmpv4ChecksumOffset:], checksum)
}
// SourcePort implements Transport.SourcePort.
@@ -102,3 +119,51 @@ func (ICMPv4) SetDestinationPort(uint16) {
func (b ICMPv4) Payload() []byte {
return b[ICMPv4PayloadOffset:]
}
+
+// MTU retrieves the MTU field from an ICMPv4 message.
+func (b ICMPv4) MTU() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4MTUOffset:])
+}
+
+// SetMTU sets the MTU field from an ICMPv4 message.
+func (b ICMPv4) SetMTU(mtu uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv4 message.
+func (b ICMPv4) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4IdentOffset:])
+}
+
+// SetIdent sets the Ident field from an ICMPv4 message.
+func (b ICMPv4) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv4 message.
+func (b ICMPv4) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field from an ICMPv4 message.
+func (b ICMPv4) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence)
+}
+
+// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
+// and payload.
+func ICMPv4Checksum(h ICMPv4, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := uint16(0)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
index 3cc57e234..1125a7d14 100644
--- a/pkg/tcpip/header/icmpv6.go
+++ b/pkg/tcpip/header/icmpv6.go
@@ -18,6 +18,7 @@ import (
"encoding/binary"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
@@ -25,14 +26,18 @@ type ICMPv6 []byte
const (
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
- ICMPv6MinimumSize = 4
+ ICMPv6MinimumSize = 8
+
+ // ICMPv6PayloadOffset is the offset of the payload in an
+ // ICMP packet.
+ ICMPv6PayloadOffset = 8
// ICMPv6ProtocolNumber is the ICMP transport protocol number.
ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
// neighbor solicitation packet.
- ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
+ ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 16
// ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
ICMPv6NeighborAdvertSize = 32
@@ -42,11 +47,27 @@ const (
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
- ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
+ ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
// packet-too-big packet.
- ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
+ ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
+
+ // icmpv6ChecksumOffset is the offset of the checksum field
+ // in an ICMPv6 message.
+ icmpv6ChecksumOffset = 2
+
+ // icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
+ // PacketTooBig message.
+ icmpv6MTUOffset = 4
+
+ // icmpv6IdentOffset is the offset of the ident field
+ // in a ICMPv6 Echo Request/Reply message.
+ icmpv6IdentOffset = 4
+
+ // icmpv6SequenceOffset is the offset of the sequence field
+ // in a ICMPv6 Echo Request/Reply message.
+ icmpv6SequenceOffset = 6
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
@@ -89,12 +110,12 @@ func (b ICMPv6) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
- return binary.BigEndian.Uint16(b[2:])
+ return binary.BigEndian.Uint16(b[icmpv6ChecksumOffset:])
}
// SetChecksum calculates and sets the ICMP checksum field.
func (b ICMPv6) SetChecksum(checksum uint16) {
- binary.BigEndian.PutUint16(b[2:], checksum)
+ binary.BigEndian.PutUint16(b[icmpv6ChecksumOffset:], checksum)
}
// SourcePort implements Transport.SourcePort.
@@ -115,7 +136,60 @@ func (ICMPv6) SetSourcePort(uint16) {
func (ICMPv6) SetDestinationPort(uint16) {
}
+// MTU retrieves the MTU field from an ICMPv6 message.
+func (b ICMPv6) MTU() uint32 {
+ return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
+}
+
+// SetMTU sets the MTU field from an ICMPv6 message.
+func (b ICMPv6) SetMTU(mtu uint32) {
+ binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
+}
+
+// Ident retrieves the Ident field from an ICMPv6 message.
+func (b ICMPv6) Ident() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
+}
+
+// SetIdent sets the Ident field from an ICMPv6 message.
+func (b ICMPv6) SetIdent(ident uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
+}
+
+// Sequence retrieves the Sequence field from an ICMPv6 message.
+func (b ICMPv6) Sequence() uint16 {
+ return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
+}
+
+// SetSequence sets the Sequence field from an ICMPv6 message.
+func (b ICMPv6) SetSequence(sequence uint16) {
+ binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
+}
+
// Payload implements Transport.Payload.
func (b ICMPv6) Payload() []byte {
- return b[ICMPv6MinimumSize:]
+ return b[ICMPv6PayloadOffset:]
+}
+
+// ICMPv6Checksum calculates the ICMP checksum over the provided ICMP header,
+// IPv6 src/dst addresses and the payload.
+func ICMPv6Checksum(h ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
+ // Calculate the IPv6 pseudo-header upper-layer checksum.
+ xsum := Checksum([]byte(src), 0)
+ xsum = Checksum([]byte(dst), xsum)
+ var upperLayerLength [4]byte
+ binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
+ xsum = Checksum(upperLayerLength[:], xsum)
+ xsum = Checksum([]byte{0, 0, 0, uint8(ICMPv6ProtocolNumber)}, xsum)
+ for _, v := range vv.Views() {
+ xsum = Checksum(v, xsum)
+ }
+
+ // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
+ h2, h3 := h[2], h[3]
+ h[2], h[3] = 0, 0
+ xsum = ^Checksum(h, xsum)
+ h[2], h[3] = h2, h3
+
+ return xsum
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 17fc9c68e..554632a64 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -21,16 +21,18 @@ import (
)
const (
- versIHL = 0
- tos = 1
- totalLen = 2
- id = 4
- flagsFO = 6
- ttl = 8
- protocol = 9
- checksum = 10
- srcAddr = 12
- dstAddr = 16
+ versIHL = 0
+ tos = 1
+ // IPv4TotalLenOffset is the offset of the total length field in the
+ // IPv4 header.
+ IPv4TotalLenOffset = 2
+ id = 4
+ flagsFO = 6
+ ttl = 8
+ protocol = 9
+ checksum = 10
+ srcAddr = 12
+ dstAddr = 16
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
@@ -103,6 +105,11 @@ const (
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+
+ // IPv4MinimumProcessableDatagramSize is the minimum size of an IP
+ // packet that every IPv4 capable host must be able to
+ // process/reassemble.
+ IPv4MinimumProcessableDatagramSize = 576
)
// Flags that may be set in an IPv4 packet.
@@ -163,7 +170,7 @@ func (b IPv4) FragmentOffset() uint16 {
// TotalLength returns the "total length" field of the ipv4 header.
func (b IPv4) TotalLength() uint16 {
- return binary.BigEndian.Uint16(b[totalLen:])
+ return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
}
// Checksum returns the checksum field of the ipv4 header.
@@ -209,7 +216,7 @@ func (b IPv4) SetTOS(v uint8, _ uint32) {
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
- binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+ binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
// SetChecksum sets the checksum field of the ipv4 header.
@@ -265,7 +272,7 @@ func (b IPv4) Encode(i *IPv4Fields) {
// packets are produced.
func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
b.SetTotalLength(totalLength)
- checksum := Checksum(b[totalLen:totalLen+2], partialChecksum)
+ checksum := Checksum(b[IPv4TotalLenOffset:IPv4TotalLenOffset+2], partialChecksum)
b.SetChecksum(^checksum)
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 31be42ce0..9d3abc0e4 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -22,12 +22,14 @@ import (
)
const (
- versTCFL = 0
- payloadLen = 4
- nextHdr = 6
- hopLimit = 7
- v6SrcAddr = 8
- v6DstAddr = 24
+ versTCFL = 0
+ // IPv6PayloadLenOffset is the offset of the PayloadLength field in
+ // IPv6 header.
+ IPv6PayloadLenOffset = 4
+ nextHdr = 6
+ hopLimit = 7
+ v6SrcAddr = 8
+ v6DstAddr = v6SrcAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
@@ -74,6 +76,13 @@ const (
// IPv6Version is the version of the ipv6 protocol.
IPv6Version = 6
+ // IPv6AllNodesMulticastAddress is a link-local multicast group that
+ // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all nodes on a link.
+ //
+ // The address is ff02::1.
+ IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
@@ -94,7 +103,7 @@ var IPv6EmptySubnet = func() tcpip.Subnet {
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
- return binary.BigEndian.Uint16(b[payloadLen:])
+ return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
// HopLimit returns the value of the "hop limit" field of the ipv6 header.
@@ -119,13 +128,13 @@ func (b IPv6) Payload() []byte {
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
- return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
- return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+ return tcpip.Address(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
@@ -148,18 +157,18 @@ func (b IPv6) SetTOS(t uint8, l uint32) {
// SetPayloadLength sets the "payload length" field of the ipv6 header.
func (b IPv6) SetPayloadLength(payloadLength uint16) {
- binary.BigEndian.PutUint16(b[payloadLen:], payloadLength)
+ binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
}
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+ copy(b[v6SrcAddr:][:IPv6AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+ copy(b[v6DstAddr:][:IPv6AddressSize], addr)
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
@@ -178,8 +187,8 @@ func (b IPv6) Encode(i *IPv6Fields) {
b.SetPayloadLength(i.PayloadLength)
b[nextHdr] = i.NextHeader
b[hopLimit] = i.HopLimit
- copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
- copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+ b.SetSourceAddress(i.SrcAddr)
+ b.SetDestinationAddress(i.DstAddr)
}
// IsValid performs basic validation on the packet.
@@ -219,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool {
return addr[0] == 0xff
}
+// IsV6UnicastAddress determines if the provided address is a valid IPv6
+// unicast (and specified) address. That is, IsV6UnicastAddress returns
+// true if addr contains IPv6AddressSize bytes, is not the unspecified
+// address and is not a multicast address.
+func IsV6UnicastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ // Must not be unspecified
+ if addr == IPv6Any {
+ return false
+ }
+
+ // Return if not a multicast.
+ return addr[0] != 0xff
+}
+
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index c1f454805..74412c894 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -27,6 +27,11 @@ const (
udpChecksum = 6
)
+const (
+ // UDPMaximumPacketSize is the largest possible UDP packet.
+ UDPMaximumPacketSize = 0xffff
+)
+
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c40744b8e..18adb2085 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -44,14 +44,12 @@ type Endpoint struct {
}
// New creates a new channel endpoint.
-func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
+func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
+ return &Endpoint{
C: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
-
- return stack.RegisterLinkEndpoint(e), e
}
// Drain removes all outbound packets from the channel and counts them.
@@ -135,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index d786d8fdf..8fa9e3984 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -8,8 +9,8 @@ go_library(
"endpoint.go",
"endpoint_unsafe.go",
"mmap.go",
- "mmap_amd64.go",
- "mmap_amd64_unsafe.go",
+ "mmap_stub.go",
+ "mmap_unsafe.go",
"packet_dispatchers.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/link/fdbased",
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 77f988b9f..7636418b1 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,6 +41,7 @@ package fdbased
import (
"fmt"
+ "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -81,6 +82,19 @@ const (
PacketMMap
)
+func (p PacketDispatchMode) String() string {
+ switch p {
+ case Readv:
+ return "Readv"
+ case RecvMMsg:
+ return "RecvMMsg"
+ case PacketMMap:
+ return "PacketMMap"
+ default:
+ return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ }
+}
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -114,6 +128,9 @@ type endpoint struct {
// gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
// disabled.
gsoMaxSize uint32
+
+ // wg keeps track of running goroutines.
+ wg sync.WaitGroup
}
// Options specify the details about the fd-based endpoint to be created.
@@ -164,8 +181,9 @@ type Options struct {
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
-// open for the lifetime of the returned endpoint.
-func New(opts *Options) (tcpip.LinkEndpointID, error) {
+// open for the lifetime of the returned endpoint (until after the endpoint has
+// stopped being using and Wait returns).
+func New(opts *Options) (stack.LinkEndpoint, error) {
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
caps |= stack.CapabilityRXChecksumOffload
@@ -190,7 +208,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
if len(opts.FDs) == 0 {
- return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
+ return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
e := &endpoint{
@@ -207,12 +225,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
for i := 0; i < len(e.fds); i++ {
fd := e.fds[i]
if err := syscall.SetNonblock(fd, true); err != nil {
- return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
+ return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
}
isSocket, err := isSocketFD(fd)
if err != nil {
- return 0, err
+ return nil, err
}
if isSocket {
if opts.GSOMaxSize != 0 {
@@ -222,12 +240,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
if err != nil {
- return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err)
}
e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
}
- return stack.RegisterLinkEndpoint(e), nil
+ return e, nil
}
func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
@@ -290,7 +308,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
for i := range e.inboundDispatchers {
- go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
}
}
@@ -320,6 +342,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
+// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
+// reading from its FD.
+func (e *endpoint) Wait() {
+ e.wg.Wait()
+}
+
// virtioNetHdr is declared in linux/virtio_net.h.
type virtioNetHdr struct {
flags uint8
@@ -435,14 +463,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf
}
// NewInjectable creates a new fd-based InjectableEndpoint.
-func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) {
+func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint {
syscall.SetNonblock(fd, true)
- e := &InjectableEndpoint{endpoint: endpoint{
+ return &InjectableEndpoint{endpoint: endpoint{
fds: []int{fd},
mtu: mtu,
caps: capabilities,
}}
-
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index e305252d6..04406bc9a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context {
}
opt.FDs = []int{fds[1]}
- epID, err := New(opt)
+ ep, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
}
- ep := stack.FindLinkEndpoint(epID).(*endpoint)
c := &context{
t: t,
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index 2dca173c2..8bfeb97e4 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -12,12 +12,183 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build !linux !amd64
+// +build linux,amd64 linux,arm64
package fdbased
-// Stubbed out version for non-linux/non-amd64 platforms.
+import (
+ "encoding/binary"
+ "syscall"
-func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- return nil, nil
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
+)
+
+const (
+ tPacketAlignment = uintptr(16)
+ tpStatusKernel = 0
+ tpStatusUser = 1
+ tpStatusCopy = 2
+ tpStatusLosing = 4
+)
+
+// We overallocate the frame size to accommodate space for the
+// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
+//
+// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
+//
+// NOTE:
+// Frames need to be aligned at 16 byte boundaries.
+// BlockSize needs to be page aligned.
+//
+// For details see PACKET_MMAP setting constraints in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+const (
+ tpFrameSize = 65536 + 128
+ tpBlockSize = tpFrameSize * 32
+ tpBlockNR = 1
+ tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
+)
+
+// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
+// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
+func tPacketAlign(v uintptr) uintptr {
+ return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
+}
+
+// tPacketReq is the tpacket_req structure as described in
+// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
+type tPacketReq struct {
+ tpBlockSize uint32
+ tpBlockNR uint32
+ tpFrameSize uint32
+ tpFrameNR uint32
+}
+
+// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
+type tPacketHdr []byte
+
+const (
+ tpStatusOffset = 0
+ tpLenOffset = 8
+ tpSnapLenOffset = 12
+ tpMacOffset = 16
+ tpNetOffset = 18
+ tpSecOffset = 20
+ tpUSecOffset = 24
+)
+
+func (t tPacketHdr) tpLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpLenOffset:])
+}
+
+func (t tPacketHdr) tpSnapLen() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
+}
+
+func (t tPacketHdr) tpMac() uint16 {
+ return binary.LittleEndian.Uint16(t[tpMacOffset:])
+}
+
+func (t tPacketHdr) tpNet() uint16 {
+ return binary.LittleEndian.Uint16(t[tpNetOffset:])
+}
+
+func (t tPacketHdr) tpSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpSecOffset:])
+}
+
+func (t tPacketHdr) tpUSec() uint32 {
+ return binary.LittleEndian.Uint32(t[tpUSecOffset:])
+}
+
+func (t tPacketHdr) Payload() []byte {
+ return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
+}
+
+// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
+// See: mmap_amd64_unsafe.go for implementation details.
+type packetMMapDispatcher struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // e is the endpoint this dispatcher is attached to.
+ e *endpoint
+
+ // ringBuffer is only used when PacketMMap dispatcher is used and points
+ // to the start of the mmapped PACKET_RX_RING buffer.
+ ringBuffer []byte
+
+ // ringOffset is the current offset into the ring buffer where the next
+ // inbound packet will be placed by the kernel.
+ ringOffset int
+}
+
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
+ hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ for hdr.tpStatus()&tpStatusUser == 0 {
+ event := rawfile.PollEvent{
+ FD: int32(d.fd),
+ Events: unix.POLLIN | unix.POLLERR,
+ }
+ if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ if errno == syscall.EINTR {
+ continue
+ }
+ return nil, rawfile.TranslateErrno(errno)
+ }
+ if hdr.tpStatus()&tpStatusCopy != 0 {
+ // This frame is truncated so skip it after flipping the
+ // buffer to the kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
+ continue
+ }
+ }
+
+ // Copy out the packet from the mmapped frame to a locally owned buffer.
+ pkt := make([]byte, hdr.tpSnapLen())
+ copy(pkt, hdr.Payload())
+ // Release packet to kernel.
+ hdr.setTPStatus(tpStatusKernel)
+ d.ringOffset = (d.ringOffset + 1) % tpFrameNR
+ return pkt, nil
+}
+
+// dispatch reads packets from an mmaped ring buffer and dispatches them to the
+// network stack.
+func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
+ pkt, err := d.readMMappedPacket()
+ if err != nil {
+ return false, err
+ }
+ var (
+ p tcpip.NetworkProtocolNumber
+ remote, local tcpip.LinkAddress
+ )
+ if d.e.hdrSize > 0 {
+ eth := header.Ethernet(pkt)
+ p = eth.Type()
+ remote = eth.SourceAddress()
+ local = eth.DestinationAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(pkt) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ return true, nil
+ }
+ }
+
+ pkt = pkt[d.e.hdrSize:]
+ d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
+ return true, nil
}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64.go b/pkg/tcpip/link/fdbased/mmap_amd64.go
deleted file mode 100644
index 029f86a18..000000000
--- a/pkg/tcpip/link/fdbased/mmap_amd64.go
+++ /dev/null
@@ -1,194 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build linux,amd64
-
-package fdbased
-
-import (
- "encoding/binary"
- "syscall"
-
- "golang.org/x/sys/unix"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
- "gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
-)
-
-const (
- tPacketAlignment = uintptr(16)
- tpStatusKernel = 0
- tpStatusUser = 1
- tpStatusCopy = 2
- tpStatusLosing = 4
-)
-
-// We overallocate the frame size to accommodate space for the
-// TPacketHdr+RawSockAddrLinkLayer+MAC header and any padding.
-//
-// Memory allocated for the ring buffer: tpBlockSize * tpBlockNR = 2 MiB
-//
-// NOTE:
-// Frames need to be aligned at 16 byte boundaries.
-// BlockSize needs to be page aligned.
-//
-// For details see PACKET_MMAP setting constraints in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-const (
- tpFrameSize = 65536 + 128
- tpBlockSize = tpFrameSize * 32
- tpBlockNR = 1
- tpFrameNR = (tpBlockSize * tpBlockNR) / tpFrameSize
-)
-
-// tPacketAlign aligns the pointer v at a tPacketAlignment boundary. Direct
-// translation of the TPACKET_ALIGN macro in <linux/if_packet.h>.
-func tPacketAlign(v uintptr) uintptr {
- return (v + tPacketAlignment - 1) & uintptr(^(tPacketAlignment - 1))
-}
-
-// tPacketReq is the tpacket_req structure as described in
-// https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt
-type tPacketReq struct {
- tpBlockSize uint32
- tpBlockNR uint32
- tpFrameSize uint32
- tpFrameNR uint32
-}
-
-// tPacketHdr is tpacket_hdr structure as described in <linux/if_packet.h>
-type tPacketHdr []byte
-
-const (
- tpStatusOffset = 0
- tpLenOffset = 8
- tpSnapLenOffset = 12
- tpMacOffset = 16
- tpNetOffset = 18
- tpSecOffset = 20
- tpUSecOffset = 24
-)
-
-func (t tPacketHdr) tpLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpLenOffset:])
-}
-
-func (t tPacketHdr) tpSnapLen() uint32 {
- return binary.LittleEndian.Uint32(t[tpSnapLenOffset:])
-}
-
-func (t tPacketHdr) tpMac() uint16 {
- return binary.LittleEndian.Uint16(t[tpMacOffset:])
-}
-
-func (t tPacketHdr) tpNet() uint16 {
- return binary.LittleEndian.Uint16(t[tpNetOffset:])
-}
-
-func (t tPacketHdr) tpSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpSecOffset:])
-}
-
-func (t tPacketHdr) tpUSec() uint32 {
- return binary.LittleEndian.Uint32(t[tpUSecOffset:])
-}
-
-func (t tPacketHdr) Payload() []byte {
- return t[uint32(t.tpMac()) : uint32(t.tpMac())+t.tpSnapLen()]
-}
-
-// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
-// See: mmap_amd64_unsafe.go for implementation details.
-type packetMMapDispatcher struct {
- // fd is the file descriptor used to send and receive packets.
- fd int
-
- // e is the endpoint this dispatcher is attached to.
- e *endpoint
-
- // ringBuffer is only used when PacketMMap dispatcher is used and points
- // to the start of the mmapped PACKET_RX_RING buffer.
- ringBuffer []byte
-
- // ringOffset is the current offset into the ring buffer where the next
- // inbound packet will be placed by the kernel.
- ringOffset int
-}
-
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, *tcpip.Error) {
- hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
- for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
- if errno == syscall.EINTR {
- continue
- }
- return nil, rawfile.TranslateErrno(errno)
- }
- if hdr.tpStatus()&tpStatusCopy != 0 {
- // This frame is truncated so skip it after flipping the
- // buffer to the kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- hdr = (tPacketHdr)(d.ringBuffer[d.ringOffset*tpFrameSize:])
- continue
- }
- }
-
- // Copy out the packet from the mmapped frame to a locally owned buffer.
- pkt := make([]byte, hdr.tpSnapLen())
- copy(pkt, hdr.Payload())
- // Release packet to kernel.
- hdr.setTPStatus(tpStatusKernel)
- d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
-}
-
-// dispatch reads packets from an mmaped ring buffer and dispatches them to the
-// network stack.
-func (d *packetMMapDispatcher) dispatch() (bool, *tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
- return false, err
- }
- var (
- p tcpip.NetworkProtocolNumber
- remote, local tcpip.LinkAddress
- )
- if d.e.hdrSize > 0 {
- eth := header.Ethernet(pkt)
- p = eth.Type()
- remote = eth.SourceAddress()
- local = eth.DestinationAddress()
- } else {
- // We don't get any indication of what the packet is, so try to guess
- // if it's an IPv4 or IPv6 packet.
- switch header.IPVersion(pkt) {
- case header.IPv4Version:
- p = header.IPv4ProtocolNumber
- case header.IPv6Version:
- p = header.IPv6ProtocolNumber
- default:
- return true, nil
- }
- }
-
- pkt = pkt[d.e.hdrSize:]
- d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, buffer.NewVectorisedView(len(pkt), []buffer.View{buffer.View(pkt)}))
- return true, nil
-}
diff --git a/test/runtimes/runtimes.go b/pkg/tcpip/link/fdbased/mmap_stub.go
index 2568e07fe..67be52d67 100644
--- a/test/runtimes/runtimes.go
+++ b/pkg/tcpip/link/fdbased/mmap_stub.go
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package runtimes provides language tests for runsc runtimes.
-// Each test calls docker commands to start up a container for each supported runtime,
-// and tests that its respective language tests are behaving as expected, like
-// connecting to a port or looking at the output. The container is killed and deleted
-// at the end.
-package runtimes
+// +build !linux !amd64,!arm64
+
+package fdbased
+
+// Stubbed out version for non-linux/non-amd64/non-arm64 platforms.
+
+func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ return nil, nil
+}
diff --git a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 47cb1d1cc..3894185ae 100644
--- a/pkg/tcpip/link/fdbased/mmap_amd64_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
package fdbased
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index ab6a53988..b36629d2c 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -32,8 +32,8 @@ type endpoint struct {
// New creates a new loopback endpoint. This link-layer endpoint just turns
// outbound packets into inbound packets.
-func New() tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{})
+func New() stack.LinkEndpoint {
+ return &endpoint{}
}
// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
@@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index ea12ef1ac..1bab380b0 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index a577a3d52..7c946101d 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -104,10 +104,16 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *
return endpoint.WriteRawPacket(dest, packet)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (m *InjectableEndpoint) Wait() {
+ for _, ep := range m.routes {
+ ep.Wait()
+ }
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
-func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) {
- e := &InjectableEndpoint{
+func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
+ return &InjectableEndpoint{
routes: routes,
}
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 174b9330f..3086fec00 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc
if err != nil {
t.Fatal("Failed to create socket pair:", err)
}
- _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
+ underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint}
- _, endpoint := NewInjectableEndpoint(routes)
+ endpoint := NewInjectableEndpoint(routes)
return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP
}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 6e3a7a9d7..2e8bc772a 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -6,8 +6,9 @@ go_library(
name = "rawfile",
srcs = [
"blockingpoll_amd64.s",
- "blockingpoll_amd64_unsafe.go",
- "blockingpoll_unsafe.go",
+ "blockingpoll_arm64.s",
+ "blockingpoll_noyield_unsafe.go",
+ "blockingpoll_yield_unsafe.go",
"errors.go",
"rawfile_unsafe.go",
],
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_arm64.s b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
new file mode 100644
index 000000000..b62888b93
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_arm64.s
@@ -0,0 +1,42 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// BlockingPoll makes the ppoll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (n int, err syscall.Errno)
+TEXT ·BlockingPoll(SB),NOSPLIT,$0-40
+ BL ·callEntersyscallblock(SB)
+ MOVD fds+0(FP), R0
+ MOVD nfds+8(FP), R1
+ MOVD timeout+16(FP), R2
+ MOVD $0x0, R3 // sigmask parameter which isn't used here
+ MOVD $0x49, R8 // SYS_PPOLL
+ SVC
+ CMP $0xfffffffffffff001, R0
+ BLS ok
+ MOVD $-1, R1
+ MOVD R1, n+24(FP)
+ NEG R0, R0
+ MOVD R0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
+ok:
+ MOVD R0, n+24(FP)
+ MOVD $0, err+32(FP)
+ BL ·callExitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
index 84dc0e918..621ab8d29 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_noyield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,!amd64
+// +build linux,!amd64,!arm64
package rawfile
@@ -22,7 +22,7 @@ import (
)
// BlockingPoll is just a stub function that forwards to the ppoll() system call
-// on non-amd64 platforms.
+// on non-amd64 and non-arm64 platforms.
func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno) {
n, _, e := syscall.Syscall6(syscall.SYS_PPOLL, uintptr(unsafe.Pointer(fds)),
uintptr(nfds), uintptr(unsafe.Pointer(timeout)), 0, 0, 0)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 47039a446..dda3b10a6 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_amd64_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build linux,amd64
+// +build linux,amd64 linux,arm64
// +build go1.12
// +build !go1.14
@@ -25,6 +25,12 @@ import (
_ "unsafe" // for go:linkname
)
+// BlockingPoll on amd64/arm64 makes the ppoll() syscall while calling the
+// version of entersyscall that relinquishes the P so that other Gs can
+// run. This is meant to be called in cases when the syscall is expected to
+// block. On non amd64/arm64 platforms it just forwards to the ppoll() system
+// call.
+//
//go:noescape
func BlockingPoll(fds *PollEvent, nfds int, timeout *syscall.Timespec) (int, syscall.Errno)
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index f2998aa98..0a5ea3dc4 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 94725cb11..330ed5e94 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index 160a8f864..de1ce043d 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 834ea5c40..9e71d4edf 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -94,7 +94,7 @@ type endpoint struct {
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) {
+func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
e := &endpoint{
mtu: mtu,
bufferSize: bufferSize,
@@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc
}
if err := e.tx.init(bufferSize, &tx); err != nil {
- return 0, err
+ return nil, err
}
if err := e.rx.init(bufferSize, &rx); err != nil {
e.tx.cleanup()
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(e), nil
+ return e, nil
}
// Close frees all resources associated with the endpoint.
@@ -132,7 +132,8 @@ func (e *endpoint) Close() {
}
}
-// Wait waits until all workers have stopped after a Close() call.
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
func (e *endpoint) Wait() {
e.completed.Wait()
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 98036f367..0e9ba0846 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
if err != nil {
t.Fatalf("New failed: %v", err)
}
- c.ep = stack.FindLinkEndpoint(id).(*endpoint)
+ c.ep = ep.(*endpoint)
c.ep.Attach(c)
return c
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 36c8c46fc..e401dce44 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -58,10 +58,10 @@ type endpoint struct {
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
-func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
- })
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ return &endpoint{
+ lower: lower,
+ }
}
func zoneOffset() (int32, error) {
@@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
// snapLen is the maximum amount of a packet to be saved. Packets with a length
// less than or equal too snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
-func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) {
if err := writePCAPHeader(file, snapLen); err != nil {
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
+ return &endpoint{
+ lower: lower,
file: file,
maxPCAPLen: snapLen,
- }), nil
+ }, nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
@@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 2597d4b3e..0746dc8ec 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 3b6ac2ff7..5a1791cb5 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -40,11 +40,10 @@ type Endpoint struct {
// New creates a new waitable link-layer endpoint. It wraps around another
// endpoint and allows the caller to block new write/dispatch calls and wait for
// the inflight ones to finish before returning.
-func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
- lower: stack.FindLinkEndpoint(lower),
+func New(lower stack.LinkEndpoint) *Endpoint {
+ return &Endpoint{
+ lower: lower,
}
- return stack.RegisterLinkEndpoint(e), e
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
@@ -121,3 +120,6 @@ func (e *Endpoint) WaitWrite() {
func (e *Endpoint) WaitDispatch() {
e.dispatchGate.Close()
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 56e18ecb0..ae23c96b7 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -70,9 +70,12 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Write and check that it goes through.
wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
@@ -97,7 +100,7 @@ func TestWaitWrite(t *testing.T) {
func TestWaitDispatch(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Check that attach happens.
wep.Attach(ep)
@@ -139,7 +142,7 @@ func TestOtherMethods(t *testing.T) {
hdrLen: hdrLen,
linkAddr: linkAddr,
}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
if v := wep.MTU(); v != mtu {
t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index f36f49453..9d16ff8c9 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index d95d44f56..df0d3a8c0 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ea7296e6a..26cf1c528 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -16,9 +16,9 @@
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
-// To use it in the networking stack, pass arp.ProtocolName as one of the
-// network protocols when calling stack.New. Then add an "arp" address to
-// every NIC on the stack that should respond to ARP requests. That is:
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
@@ -33,9 +33,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ARP protocol name.
- ProtocolName = "arp"
-
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
@@ -112,11 +109,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
- fallthrough // also fill the cache from requests
case header.ARPReply:
- addr := tcpip.Address(h.ProtocolAddressSender())
- linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
- e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
@@ -204,8 +197,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 4c4b54469..88b57ec03 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -44,14 +44,19 @@ type testContext struct {
}
func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
const defaultMTU = 65536
- id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(256, defaultMTU, stackLinkAddr)
+ wep := stack.LinkEndpoint(ep)
+
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -73,7 +78,7 @@ func newTestContext(t *testing.T) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 118bfc763..c5c7aad86 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "reassembler_list",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 6bbfcd97f..a9741622e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
@@ -169,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen
}
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -182,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -319,7 +328,8 @@ func TestIPv4ReceiveControl(t *testing.T) {
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
icmp.SetType(header.ICMPv4DstUnreachable)
icmp.SetCode(c.code)
- copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef})
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
// Create the inner IPv4 header.
ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
@@ -539,7 +549,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
defer ep.Close()
- dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
}
@@ -559,10 +569,11 @@ func TestIPv6ReceiveControl(t *testing.T) {
icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
icmp.SetType(c.typ)
icmp.SetCode(c.code)
- copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+ icmp.SetIdent(0xdead)
+ icmp.SetSequence(0xbeef)
// Create the inner IPv6 header.
- ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6PayloadOffset:])
ip.Encode(&header.IPv6Fields{
PayloadLength: 100,
NextHeader: 10,
@@ -574,7 +585,7 @@ func TestIPv6ReceiveControl(t *testing.T) {
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
ip.SetNextHeader(header.IPv6FragmentHeader)
- frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize:])
frag.Encode(&header.IPv6FragmentFields{
NextHeader: 10,
FragmentOffset: *c.fragmentOffset,
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index be84fa63d..58e537aad 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index 497164cbb..a25756443 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -15,8 +15,6 @@
package ipv4
import (
- "encoding/binary"
-
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -117,7 +115,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
e.handleControl(stack.ControlPortUnreachable, 0, vv)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:]))
+ mtu := uint32(h.MTU())
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7a06f525..b7b07a6c1 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -14,9 +14,9 @@
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
@@ -32,9 +32,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv4 protocol name.
- ProtocolName = "ipv4"
-
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -53,6 +50,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
+ protocol *protocol
}
// NewEndpoint creates a new ipv4 endpoint.
@@ -64,6 +62,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
}
return e, nil
@@ -204,7 +203,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1)
}
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -267,7 +266,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
if payload.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
}
ip.SetID(uint16(id))
}
@@ -325,14 +324,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {}
-type protocol struct{}
-
-// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+type protocol struct {
+ ids []uint32
+ hashIV uint32
}
// Number returns the ipv4 protocol number.
@@ -378,7 +372,7 @@ func calculateMTU(mtu uint32) uint32 {
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
@@ -386,22 +380,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-var (
- ids []uint32
- hashIV uint32
-)
-
-func init() {
- ids = make([]uint32, buckets)
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
- hashIV = r[buckets]
+ hashIV := r[buckets]
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+ return &protocol{ids: ids, hashIV: hashIV}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 1b5a55bea..b6641ccc3 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -33,14 +33,17 @@ import (
)
func TestExcludeBroadcast(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
const defaultMTU = 65536
- id, _ := channel.New(256, defaultMTU, "")
+ ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ ep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -184,15 +187,12 @@ type errorChannel struct {
// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
// will return successive errors from packetCollectorErrors until the list is
// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) {
- _, e := channel.New(size, mtu, linkAddr)
- ec := errorChannel{
- Endpoint: e,
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
+ return &errorChannel{
+ Endpoint: channel.New(size, mtu, linkAddr),
Ch: make(chan packetInfo, size),
packetCollectorErrors: packetCollectorErrors,
}
-
- return stack.RegisterLinkEndpoint(e), &ec
}
// packetInfo holds all the information about an outbound packet.
@@ -241,10 +241,11 @@ type context struct {
func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
// Make the packet and write it.
- s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
- _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- linkEPId := stack.RegisterLinkEndpoint(linkEP)
- s.CreateNIC(1, linkEPId)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ s.CreateNIC(1, ep)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -266,7 +267,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
}
return context{
Route: r,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index fae7f4507..f06622a8b 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -23,7 +24,11 @@ go_library(
go_test(
name = "ipv6_test",
size = "small",
- srcs = ["icmp_test.go"],
+ srcs = [
+ "icmp_test.go",
+ "ipv6_test.go",
+ "ndp_test.go",
+ ],
embed = [":ipv6"],
deps = [
"//pkg/tcpip",
@@ -33,6 +38,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index 5e6a59e91..b4d0295bf 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -15,14 +15,21 @@
package ipv6
import (
- "encoding/binary"
-
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+const (
+ // ndpHopLimit is the expected IP hop limit value of 255 for received
+ // NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
+ // 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
+ // drop the NDP packet. All outgoing NDP packets must use this value for
+ // its IP hop limit field.
+ ndpHopLimit = 255
+)
+
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
@@ -73,6 +80,21 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
}
h := header.ICMPv6(v)
+ // As per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1, 7.1.2 and
+ // 8.1, nodes MUST silently drop NDP packets where the Hop Limit field
+ // in the IPv6 header is not set to 255.
+ switch h.Type() {
+ case header.ICMPv6NeighborSolicit,
+ header.ICMPv6NeighborAdvert,
+ header.ICMPv6RouterSolicit,
+ header.ICMPv6RouterAdvert,
+ header.ICMPv6RedirectMsg:
+ if header.IPv6(netHeader).HopLimit() != ndpHopLimit {
+ received.Invalid.Increment()
+ return
+ }
+ }
+
// TODO(b/112892170): Meaningfully handle all ICMP types.
switch h.Type() {
case header.ICMPv6PacketTooBig:
@@ -82,7 +104,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
return
}
vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
- mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
+ mtu := h.MTU()
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
case header.ICMPv6DstUnreachable:
@@ -100,13 +122,11 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv6NeighborSolicit:
received.NeighborSolicit.Increment()
- e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
-
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
@@ -132,7 +152,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
sent.Dropped.Increment()
@@ -146,7 +166,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
received.Invalid.Increment()
return
}
- targetAddr := tcpip.Address(v[8:][:16])
+ targetAddr := tcpip.Address(v[8:][:header.IPv6AddressSize])
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
@@ -164,7 +184,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv6EchoReply)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL()); err != nil {
sent.Dropped.Increment()
return
@@ -235,7 +255,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
pkt[icmpV6LengthOffset] = 1
copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(hdr.UsedLength())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
@@ -274,24 +294,3 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
}
return "", false
}
-
-func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
- // Calculate the IPv6 pseudo-header upper-layer checksum.
- xsum := header.Checksum([]byte(src), 0)
- xsum = header.Checksum([]byte(dst), xsum)
- var upperLayerLength [4]byte
- binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
- xsum = header.Checksum(upperLayerLength[:], xsum)
- xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
- for _, v := range vv.Views() {
- xsum = header.Checksum(v, xsum)
- }
-
- // h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
- h2, h3 := h[2], h[3]
- h[2], h[3] = 0, 0
- xsum = ^header.Checksum(h, xsum)
- h[2], h[3] = h2, h3
-
- return xsum
-}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index d0dc72506..01f5a17ec 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -81,10 +81,12 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li
}
func TestICMPCounts(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
{
- id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
@@ -153,7 +155,7 @@ func TestICMPCounts(t *testing.T) {
hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
pkt := header.ICMPv6(hdr.Prepend(typ.size))
pkt.SetType(typ.typ)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
handleIPv6Payload(hdr)
}
@@ -206,41 +208,38 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
- s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
- s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
}
const defaultMTU = 65536
- _, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
- c.linkEP0 = linkEP0
- wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
- id0 := stack.RegisterLinkEndpoint(wrappedEP0)
+ c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+
+ wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
- id0 = sniffer.New(id0)
+ wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, id0); err != nil {
+ if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil {
- t.Fatalf("AddAddress sn lladdr0: %v", err)
- }
- _, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
- c.linkEP1 = linkEP1
- wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
- id1 := stack.RegisterLinkEndpoint(wrappedEP1)
- if err := c.s1.CreateNIC(1, id1); err != nil {
+ c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
+ if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil {
- t.Fatalf("AddAddress sn lladdr1: %v", err)
- }
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
if err != nil {
@@ -321,7 +320,7 @@ func TestLinkResolution(t *testing.T) {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
- pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
payload := tcpip.SlicePayload(hdr.View())
// We can't send our payload directly over the route because that
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331a8bdaa..7de6a4546 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -14,9 +14,9 @@
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
@@ -28,9 +28,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv6 protocol name.
- ProtocolName = "ipv6"
-
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -160,14 +157,6 @@ func (*endpoint) Close() {}
type protocol struct{}
-// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
-}
-
// Number returns the ipv6 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
@@ -221,8 +210,7 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
new file mode 100644
index 000000000..78c674c2c
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -0,0 +1,266 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ // The least significant 3 bytes are the same as addr2 so both addr2 and
+ // addr3 will have the same solicited-node address.
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+)
+
+// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
+// expected Neighbor Advertisement received count after receiving the packet.
+func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ // Receive ICMP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+
+ if got := stats.NeighborAdvert.Value(); got != want {
+ t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
+ }
+}
+
+// testReceiveICMP tests receiving a UDP packet from src to dst. want is the
+// expected UDP received count after receiving the packet.
+func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+
+ // Receive UDP Packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+
+ // UDP pseudo-header checksum.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
+
+ // UDP checksum
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stat := s.Stats().UDP.PacketsReceived
+
+ if got := stat.Value(); got != want {
+ t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
+ }
+}
+
+// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
+// UDP packets destined to the IPv6 link-local all-nodes multicast address.
+func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should receive a packet destined to the all-nodes
+ // multicast address.
+ test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
+ })
+ }
+}
+
+// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
+// packets destined to the IPv6 solicited-node address of an assigned IPv6
+// address.
+func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
+ }
+
+ snmc := header.SolicitedNodeAddr(addr2)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as we haven't added
+ // those addresses.
+ test.rxf(t, s, e, addr1, snmc, 0)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ }
+
+ // Should receive a packet destined to the solicited
+ // node address of addr2/addr3 now that we have added
+ // added addr2.
+ test.rxf(t, s, e, addr1, snmc, 1)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have added addr3.
+ test.rxf(t, s, e, addr1, snmc, 2)
+
+ if err := s.RemoveAddress(1, addr2); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have removed addr2.
+ test.rxf(t, s, e, addr1, snmc, 3)
+
+ if err := s.RemoveAddress(1, addr3); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as both of them got
+ // removed.
+ test.rxf(t, s, e, addr1, snmc, 3)
+ })
+ }
+}
+
+// TestAddIpv6Address tests adding IPv6 addresses.
+func TestAddIpv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ }{
+ // This test is in response to b/140943433.
+ {
+ "Nil",
+ tcpip.Address([]byte(nil)),
+ },
+ {
+ "ValidUnicast",
+ addr1,
+ },
+ {
+ "ValidLinkLocalUnicast",
+ lladdr0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
new file mode 100644
index 000000000..e30791fe3
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -0,0 +1,181 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+)
+
+// setupStackAndEndpoint creates a stack with a single NIC with a link-local
+// address llladdr and an IPv6 endpoint to a remote with link-local address
+// rlladdr
+func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
+ t.Helper()
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
+
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable(
+ []tcpip.Route{{
+ Destination: subnet,
+ NIC: 1,
+ }},
+ )
+ }
+
+ netProto := s.NetworkProtocolInstance(ProtocolNumber)
+ if netProto == nil {
+ t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber)
+ }
+
+ ep, err := netProto.NewEndpoint(0, tcpip.AddressWithPrefix{rlladdr, netProto.DefaultPrefixLen()}, &stubLinkAddressCache{}, &stubDispatcher{}, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep
+}
+
+// TestHopLimitValidation is a test that makes sure that NDP packets are only
+// received if their IP header's hop limit is set to 255.
+func TestHopLimitValidation(t *testing.T) {
+ setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) {
+ t.Helper()
+
+ // Create a stack with the assigned link-local address lladdr0
+ // and an endpoint to lladdr1.
+ s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1)
+
+ r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err)
+ }
+
+ return s, ep, r
+ }
+
+ handleIPv6Payload := func(hdr buffer.Prependable, hopLimit uint8, ep stack.NetworkEndpoint, r *stack.Route) {
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: hopLimit,
+ SrcAddr: r.LocalAddress,
+ DstAddr: r.RemoteAddress,
+ })
+ ep.HandlePacket(r, hdr.View().ToVectorisedView())
+ }
+
+ types := []struct {
+ name string
+ typ header.ICMPv6Type
+ size int
+ statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ }{
+ {"RouterSolicit", header.ICMPv6RouterSolicit, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterSolicit
+ }},
+ {"RouterAdvert", header.ICMPv6RouterAdvert, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RouterAdvert
+ }},
+ {"NeighborSolicit", header.ICMPv6NeighborSolicit, header.ICMPv6NeighborSolicitMinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborSolicit
+ }},
+ {"NeighborAdvert", header.ICMPv6NeighborAdvert, header.ICMPv6NeighborAdvertSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.NeighborAdvert
+ }},
+ {"RedirectMsg", header.ICMPv6RedirectMsg, header.ICMPv6MinimumSize, func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
+ return stats.RedirectMsg
+ }},
+ }
+
+ for _, typ := range types {
+ t.Run(typ.name, func(t *testing.T) {
+ s, ep, r := setup(t)
+ defer r.Release()
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+ invalid := stats.Invalid
+ typStat := typ.statCounter(stats)
+
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + typ.size)
+ pkt := header.ICMPv6(hdr.Prepend(typ.size))
+ pkt.SetType(typ.typ)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
+
+ // Invalid count should initially be 0.
+ if got := invalid.Value(); got != 0 {
+ t.Fatalf("got invalid = %d, want = 0", got)
+ }
+
+ // Should not have received any ICMPv6 packets with
+ // type = typ.typ.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Receive the NDP packet with an invalid hop limit
+ // value.
+ handleIPv6Payload(hdr, ndpHopLimit-1, ep, &r)
+
+ // Invalid count should have increased.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+
+ // Rx count of NDP packet of type typ.typ should not
+ // have increased.
+ if got := typStat.Value(); got != 0 {
+ t.Fatalf("got %s = %d, want = 0", typ.name, got)
+ }
+
+ // Receive the NDP packet with a valid hop limit value.
+ handleIPv6Payload(hdr, ndpHopLimit, ep, &r)
+
+ // Rx count of NDP packet of type typ.typ should have
+ // increased.
+ if got := typStat.Value(); got != 1 {
+ t.Fatalf("got %s = %d, want = 1", typ.name, got)
+ }
+
+ // Invalid count should not have increased again.
+ if got := invalid.Value(); got != 1 {
+ t.Fatalf("got invalid = %d, want = 1", got)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 989058413..11efb4e44 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 315780c0c..30cea8996 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -19,6 +19,7 @@ import (
"math"
"math/rand"
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -27,6 +28,10 @@ const (
// FirstEphemeral is the first ephemeral port.
FirstEphemeral = 16000
+ // numEphemeralPorts it the mnumber of available ephemeral ports to
+ // Netstack.
+ numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+
anyIPAddress tcpip.Address = ""
)
@@ -40,6 +45,13 @@ type portDescriptor struct {
type PortManager struct {
mu sync.RWMutex
allocatedPorts map[portDescriptor]bindAddresses
+
+ // hint is used to pick ports ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
}
type portNode struct {
@@ -47,43 +59,76 @@ type portNode struct {
refs int
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]portNode
+// deviceNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type deviceNode map[tcpip.NICID]portNode
-// isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
- if addr == anyIPAddress {
- if len(b) == 0 {
- return true
- }
+// isAvailable checks whether binding is possible by device. If not binding to a
+// device, check against all portNodes. If binding to a specific device, check
+// against the unspecified device and the provided device.
+func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+ if bindToDevice == 0 {
+ // Trying to binding all devices.
if !reuse {
+ // Can't bind because the (addr,port) is already bound.
return false
}
- for _, n := range b {
- if !n.reuse {
+ for _, p := range d {
+ if !p.reuse {
+ // Can't bind because the (addr,port) was previously bound without reuse.
return false
}
}
return true
}
- // If all addresses for this portDescriptor are already bound, no
- // address is available.
- if n, ok := b[anyIPAddress]; ok {
- if !reuse {
+ if p, ok := d[0]; ok {
+ if !reuse || !p.reuse {
return false
}
- if !n.reuse {
+ }
+
+ if p, ok := d[bindToDevice]; ok {
+ if !reuse || !p.reuse {
return false
}
}
- if n, ok := b[addr]; ok {
- if !reuse {
+ return true
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]deviceNode
+
+// isAvailable checks whether an IP address is available to bind to. If the
+// address is the "any" address, check all other addresses. Otherwise, just
+// check against the "any" address and the provided address.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+ if addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no conflicts
+ // with all addresses.
+ for _, d := range b {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+ return true
+ }
+
+ // Check that there is no conflict with the "any" address.
+ if d, ok := b[anyIPAddress]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
return false
}
- return n.reuse
}
+
+ // Check that this is no conflict with the provided address.
+ if d, ok := b[addr]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+
return true
}
@@ -97,11 +142,40 @@ func NewPortManager() *PortManager {
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
- count := uint16(math.MaxUint16 - FirstEphemeral + 1)
- offset := uint16(rand.Int31n(int32(count)))
+ offset := uint32(rand.Int31n(numEphemeralPorts))
+ return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+}
+
+// portHint atomically reads and returns the s.hint value.
+func (s *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&s.hint)
+}
+
+// incPortHint atomically increments s.hint by 1.
+func (s *PortManager) incPortHint() {
+ atomic.AddUint32(&s.hint, 1)
+}
- for i := uint16(0); i < count; i++ {
- port = FirstEphemeral + (offset+i)%count
+// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// iterates over all ephemeral ports, allowing the caller to decide whether a
+// given port is suitable for its needs and stopping when a port is found or an
+// error occurs.
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
+ if err == nil {
+ s.incPortHint()
+ }
+ return p, err
+
+}
+
+// pickEphemeralPort starts at the offset specified from the FirstEphemeral port
+// and iterates over the number of ports specified by count and allows the
+// caller to decide whether a given port is suitable for its needs, and stopping
+// when a port is found or an error occurs.
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ for i := uint32(0); i < count; i++ {
+ port = uint16(FirstEphemeral + (offset+i)%count)
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -116,17 +190,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse) {
+ if !addrs.isAvailable(addr, reuse, bindToDevice) {
return false
}
}
@@ -138,14 +212,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -153,13 +227,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
return false
}
@@ -171,11 +245,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
- if n, ok := m[addr]; ok {
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ if n, ok := d[bindToDevice]; ok {
n.refs++
- m[addr] = n
+ d[bindToDevice] = n
} else {
- m[addr] = portNode{reuse: reuse, refs: 1}
+ d[bindToDevice] = portNode{reuse: reuse, refs: 1}
}
}
@@ -184,22 +263,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
- n, ok := m[addr]
+ d, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n, ok := d[bindToDevice]
if !ok {
continue
}
n.refs--
+ d[bindToDevice] = n
if n.refs == 0 {
+ delete(d, bindToDevice)
+ }
+ if len(d) == 0 {
delete(m, addr)
- } else {
- m[addr] = n
}
if len(m) == 0 {
delete(s.allocatedPorts, desc)
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 689401661..19f4833fc 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,6 +15,7 @@
package ports
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,6 +35,7 @@ type portReserveTestAction struct {
want *tcpip.Error
reuse bool
release bool
+ device tcpip.NICID
}
func TestPortReservation(t *testing.T) {
@@ -100,6 +102,112 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, release: true},
{port: 24, ip: anyIPAddress, reuse: false, want: nil},
},
+ }, {
+ tname: "bind twice with device fails",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 3, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 1, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 2, want: nil},
+ },
+ }, {
+ tname: "bind to device and then without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ },
+ }, {
+ tname: "binding with reuse and device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "mixing reuse and not reuse by binding to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: nil},
+ },
+ }, {
+ tname: "can't bind to 0 after mixing reuse and not reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind and release",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+
+ // Release the bind to device 0 and try again.
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuse once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "release an unreserved device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ // The below don't exist.
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ // Release all.
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ },
},
} {
t.Run(test.tname, func(t *testing.T) {
@@ -108,12 +216,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
@@ -125,7 +233,6 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
- pm := NewPortManager()
customErr := &tcpip.Error{}
for _, test := range []struct {
name string
@@ -169,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
}
})
}
}
+
+func TestPickEphemeralPortStable(t *testing.T) {
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == FirstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: FirstEphemeral + 42,
+ },
+ {
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < FirstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index e2021cd15..2239c1e66 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -126,7 +126,10 @@ func main() {
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
@@ -138,11 +141,11 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
+ linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil {
+ if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 1716be285..bca73cbb1 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -111,7 +111,10 @@ func main() {
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
@@ -128,7 +131,7 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{
+ linkEP, err := fdbased.New(&fdbased.Options{
FDs: []int{fd},
MTU: mtu,
EthernetHeader: *tap,
@@ -137,7 +140,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, linkID); err != nil {
+ if err := s.CreateNIC(1, linkEP); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 9986b4be3..baf88bfab 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,11 +1,28 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_template_instance(
+ name = "linkaddrentry_list",
+ out = "linkaddrentry_list.go",
+ package = "stack",
+ prefix = "linkAddrEntry",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*linkAddrEntry",
+ "Linker": "*linkAddrEntry",
+ },
+)
go_library(
name = "stack",
srcs = [
+ "icmp_rate_limit.go",
"linkaddrcache.go",
+ "linkaddrentry_list.go",
"nic.go",
"registration.go",
"route.go",
@@ -19,6 +36,7 @@ go_library(
],
deps = [
"//pkg/ilist",
+ "//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -28,6 +46,7 @@ go_library(
"//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/waiter",
+ "@org_golang_x_time//rate:go_default_library",
],
)
@@ -36,6 +55,7 @@ go_test(
size = "small",
srcs = [
"stack_test.go",
+ "transport_demuxer_test.go",
"transport_test.go",
],
deps = [
@@ -46,6 +66,9 @@ go_test(
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
@@ -60,3 +83,11 @@ go_test(
"//pkg/tcpip",
],
)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "linkaddrentry_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
new file mode 100644
index 000000000..3a20839da
--- /dev/null
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -0,0 +1,41 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack
+
+import (
+ "golang.org/x/time/rate"
+)
+
+const (
+ // icmpLimit is the default maximum number of ICMP messages permitted by this
+ // rate limiter.
+ icmpLimit = 1000
+
+ // icmpBurst is the default number of ICMP messages that can be sent in a single
+ // burst.
+ icmpBurst = 50
+)
+
+// ICMPRateLimiter is a global rate limiter that controls the generation of
+// ICMP messages generated by the stack.
+type ICMPRateLimiter struct {
+ *rate.Limiter
+}
+
+// NewICMPRateLimiter returns a global rate limiter for controlling the rate
+// at which ICMP messages are generated by the stack.
+func NewICMPRateLimiter() *ICMPRateLimiter {
+ return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
+}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 77bb0ccb9..267df60d1 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -42,10 +42,11 @@ type linkAddrCache struct {
// resolved before failing.
resolutionAttempts int
- mu sync.Mutex
- cache map[tcpip.FullAddress]*linkAddrEntry
- next int // array index of next available entry
- entries [linkAddrCacheSize]linkAddrEntry
+ cache struct {
+ sync.Mutex
+ table map[tcpip.FullAddress]*linkAddrEntry
+ lru linkAddrEntryList
+ }
}
// entryState controls the state of a single entry in the cache.
@@ -60,9 +61,6 @@ const (
// failed means that address resolution timed out and the address
// could not be resolved.
failed
- // expired means that the cache entry has expired and the address must be
- // resolved again.
- expired
)
// String implements Stringer.
@@ -74,8 +72,6 @@ func (s entryState) String() string {
return "ready"
case failed:
return "failed"
- case expired:
- return "expired"
default:
return fmt.Sprintf("unknown(%d)", s)
}
@@ -84,64 +80,46 @@ func (s entryState) String() string {
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
+ linkAddrEntryEntry
+
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
// wakers is a set of waiters for address resolution result. Anytime
- // state transitions out of 'incomplete' these waiters are notified.
+ // state transitions out of incomplete these waiters are notified.
wakers map[*sleep.Waker]struct{}
+ // done is used to allow callers to wait on address resolution. It is nil iff
+ // s is incomplete and resolution is not yet in progress.
done chan struct{}
}
-func (e *linkAddrEntry) state() entryState {
- if e.s != expired && time.Now().After(e.expiration) {
- // Force the transition to ensure waiters are notified.
- e.changeState(expired)
- }
- return e.s
-}
-
-func (e *linkAddrEntry) changeState(ns entryState) {
- if e.s == ns {
- return
- }
-
- // Validate state transition.
- switch e.s {
- case incomplete:
- // All transitions are valid.
- case ready, failed:
- if ns != expired {
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- }
- case expired:
- // Terminal state.
- panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
- default:
- panic(fmt.Sprintf("invalid state: %s", e.s))
- }
-
+// changeState sets the entry's state to ns, notifying any waiters.
+//
+// The entry's expiration is bumped up to the greater of itself and the passed
+// expiration; the zero value indicates immediate expiration, and is set
+// unconditionally - this is an implementation detail that allows for entries
+// to be reused.
+func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) {
// Notify whoever is waiting on address resolution when transitioning
- // out of 'incomplete'.
- if e.s == incomplete {
+ // out of incomplete.
+ if e.s == incomplete && ns != incomplete {
for w := range e.wakers {
w.Assert()
}
e.wakers = nil
- if e.done != nil {
- close(e.done)
+ if ch := e.done; ch != nil {
+ close(ch)
}
+ e.done = nil
}
- e.s = ns
-}
-func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) {
- if w != nil {
- e.wakers[w] = struct{}{}
+ if expiration.IsZero() || expiration.After(e.expiration) {
+ e.expiration = expiration
}
+ e.s = ns
}
func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
@@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
- if ok {
- s := entry.state()
- if s != expired && entry.linkAddr == v {
- // Disregard repeated calls.
- return
- }
- // Check if entry is waiting for address resolution.
- if s == incomplete {
- entry.linkAddr = v
- } else {
- // Otherwise create a new entry to replace it.
- entry = c.makeAndAddEntry(k, v)
- }
- } else {
- entry = c.makeAndAddEntry(k, v)
- }
+ // Calculate expiration time before acquiring the lock, since expiration is
+ // relative to the time when information was learned, rather than when it
+ // happened to be inserted into the cache.
+ expiration := time.Now().Add(c.ageLimit)
- entry.changeState(ready)
+ c.cache.Lock()
+ entry := c.getOrCreateEntryLocked(k)
+ entry.linkAddr = v
+
+ entry.changeState(ready, expiration)
+ c.cache.Unlock()
}
-// makeAndAddEntry is a helper function to create and add a new
-// entry to the cache map and evict older entry as needed.
-func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
- // Take over the next entry.
- entry := &c.entries[c.next]
- if c.cache[entry.addr] == entry {
- delete(c.cache, entry.addr)
+// getOrCreateEntryLocked retrieves a cache entry associated with k. The
+// returned entry is always refreshed in the cache (it is reachable via the
+// map, and its place is bumped in LRU).
+//
+// If a matching entry exists in the cache, it is returned. If no matching
+// entry exists and the cache is full, an existing entry is evicted via LRU,
+// reset to state incomplete, and returned. If no matching entry exists and the
+// cache is not full, a new entry with state incomplete is allocated and
+// returned.
+func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry {
+ if entry, ok := c.cache.table[k]; ok {
+ c.cache.lru.Remove(entry)
+ c.cache.lru.PushFront(entry)
+ return entry
}
+ var entry *linkAddrEntry
+ if len(c.cache.table) == linkAddrCacheSize {
+ entry = c.cache.lru.Back()
- // Mark the soon-to-be-replaced entry as expired, just in case there is
- // someone waiting for address resolution on it.
- entry.changeState(expired)
+ delete(c.cache.table, entry.addr)
+ c.cache.lru.Remove(entry)
- *entry = linkAddrEntry{
- addr: k,
- linkAddr: v,
- expiration: time.Now().Add(c.ageLimit),
- wakers: make(map[*sleep.Waker]struct{}),
- done: make(chan struct{}),
+ // Wake waiters and mark the soon-to-be-reused entry as expired. Note
+ // that the state passed doesn't matter when the zero time is passed.
+ entry.changeState(failed, time.Time{})
+ } else {
+ entry = new(linkAddrEntry)
}
- c.cache[k] = entry
- c.next = (c.next + 1) % len(c.entries)
+ *entry = linkAddrEntry{
+ addr: k,
+ s: incomplete,
+ }
+ c.cache.table[k] = entry
+ c.cache.lru.PushFront(entry)
return entry
}
@@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
}
}
- c.mu.Lock()
- defer c.mu.Unlock()
- if entry, ok := c.cache[k]; ok {
- switch s := entry.state(); s {
- case expired:
- case ready:
- return entry.linkAddr, nil, nil
- case failed:
- return "", nil, tcpip.ErrNoLinkAddress
- case incomplete:
- // Address resolution is still in progress.
- entry.maybeAddWaker(waker)
- return "", entry.done, tcpip.ErrWouldBlock
- default:
- panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry := c.getOrCreateEntryLocked(k)
+ switch s := entry.s; s {
+ case ready, failed:
+ if !time.Now().After(entry.expiration) {
+ // Not expired.
+ switch s {
+ case ready:
+ return entry.linkAddr, nil, nil
+ case failed:
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
- }
- if linkRes == nil {
- return "", nil, tcpip.ErrNoLinkAddress
- }
+ entry.changeState(incomplete, time.Time{})
+ fallthrough
+ case incomplete:
+ if waker != nil {
+ if entry.wakers == nil {
+ entry.wakers = make(map[*sleep.Waker]struct{})
+ }
+ entry.wakers[waker] = struct{}{}
+ }
- // Add 'incomplete' entry in the cache to mark that resolution is in progress.
- e := c.makeAndAddEntry(k, "")
- e.maybeAddWaker(waker)
+ if entry.done == nil {
+ // Address resolution needs to be initiated.
+ if linkRes == nil {
+ return entry.linkAddr, nil, tcpip.ErrNoLinkAddress
+ }
- go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ entry.done = make(chan struct{})
+ go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously.
+ }
- return "", e.done, tcpip.ErrWouldBlock
+ return entry.linkAddr, entry.done, tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %s", s))
+ }
}
// removeWaker removes a waker previously added through get().
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
- c.mu.Lock()
- defer c.mu.Unlock()
+ c.cache.Lock()
+ defer c.cache.Unlock()
- if entry, ok := c.cache[k]; ok {
+ if entry, ok := c.cache.table[k]; ok {
entry.removeWaker(waker)
}
}
@@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
select {
- case <-time.After(c.resolutionTimeout):
- if stop := c.checkLinkRequest(k, i); stop {
+ case now := <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(now, k, i); stop {
return
}
case <-done:
@@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
// checkLinkRequest checks whether previous attempt to resolve address has succeeded
// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
// can stop, false if another request should be sent.
-func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
-
- entry, ok := c.cache[k]
+func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool {
+ c.cache.Lock()
+ defer c.cache.Unlock()
+ entry, ok := c.cache.table[k]
if !ok {
// Entry was evicted from the cache.
return true
}
-
- switch s := entry.state(); s {
- case ready, failed, expired:
+ switch s := entry.s; s {
+ case ready, failed:
// Entry was made ready by resolver or failed. Either way we're done.
- return true
case incomplete:
- if attempt+1 >= c.resolutionAttempts {
- // Max number of retries reached, mark entry as failed.
- entry.changeState(failed)
- return true
+ if attempt+1 < c.resolutionAttempts {
+ // No response yet, need to send another ARP request.
+ return false
}
- // No response yet, need to send another ARP request.
- return false
+ // Max number of retries reached, mark entry as failed.
+ entry.changeState(failed, now.Add(c.ageLimit))
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
+ return true
}
func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
- return &linkAddrCache{
+ c := &linkAddrCache{
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
- cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
}
+ c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+ return c
}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 924f4d240..9946b8fe8 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -29,25 +30,34 @@ type testaddr struct {
linkAddr tcpip.LinkAddress
}
-var testaddrs []testaddr
+var testAddrs = func() []testaddr {
+ var addrs []testaddr
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ addrs = append(addrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+ return addrs
+}()
type testLinkAddressResolver struct {
- cache *linkAddrCache
- delay time.Duration
+ cache *linkAddrCache
+ delay time.Duration
+ onLinkAddressRequest func()
}
func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
- go func() {
- if r.delay > 0 {
- time.Sleep(r.delay)
- }
- r.fakeRequest(addr)
- }()
+ time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
+ if f := r.onLinkAddressRequest; f != nil {
+ f()
+ }
return nil
}
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
- for _, ta := range testaddrs {
+ for _, ta := range testAddrs {
if ta.addr.Addr == addr {
r.cache.add(ta.addr, ta.linkAddr)
break
@@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
}
}
-func init() {
- for i := 0; i < 4*linkAddrCacheSize; i++ {
- addr := fmt.Sprintf("Addr%06d", i)
- testaddrs = append(testaddrs, testaddr{
- addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
- linkAddr: tcpip.LinkAddress("Link" + addr),
- })
- }
-}
-
func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- for i := len(testaddrs) - 1; i >= 0; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= 0; i-- {
+ e := testAddrs[i]
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
@@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) {
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
- e := testaddrs[i]
+ e := testAddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
@@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) {
}
}
// The earliest entries should no longer be in the cache.
- for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
- e := testaddrs[i]
+ for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- {
+ e := testAddrs[i]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
@@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) {
for r := 0; r < 16; r++ {
wg.Add(1)
go func() {
- for _, e := range testaddrs {
+ for _, e := range testAddrs {
c.add(e.addr, e.linkAddr)
c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
@@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) {
// All goroutines add in the same order and add more values than
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
- e = testaddrs[0]
+ e = testAddrs[0]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
@@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) {
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
@@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) {
func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
- e := testaddrs[0]
+ e := testAddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
@@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) {
func TestCacheResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c}
- for i, ta := range testaddrs {
+ for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
@@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) {
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
- e := testaddrs[len(testaddrs)-1]
+ e := testAddrs[len(testAddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
+ var requestCount uint32
+ linkRes.onLinkAddressRequest = func() {
+ atomic.AddUint32(&requestCount, 1)
+ }
+
// First, sanity check that resolution is working...
- e := testaddrs[0]
+ e := testAddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
@@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
+ before := atomic.LoadUint32(&requestCount)
+
e.addr.Addr += "2"
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
+
+ if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want {
+ t.Errorf("got link address request count = %d, want = %d", got, want)
+ }
}
func TestCacheResolutionTimeout(t *testing.T) {
@@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) {
c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
- e := testaddrs[0]
+ e := testAddrs[0]
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 4ef85bdfb..f6106f762 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -34,15 +34,13 @@ type NIC struct {
linkEP LinkEndpoint
loopback bool
- demux *transportDemuxer
-
- mu sync.RWMutex
- spoofing bool
- promiscuous bool
- primary map[tcpip.NetworkProtocolNumber]*ilist.List
- endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
- subnets []tcpip.Subnet
- mcastJoins map[NetworkEndpointID]int32
+ mu sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ addressRanges []tcpip.Subnet
+ mcastJoins map[NetworkEndpointID]int32
stats NICStats
}
@@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
name: name,
linkEP: ep,
loopback: loopback,
- demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -102,6 +99,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
}
}
+// enable enables the NIC. enable will attach the link to its LinkEndpoint and
+// join the IPv6 All-Nodes Multicast address (ff02::1).
+func (n *NIC) enable() *tcpip.Error {
+ n.attachLinkEndpoint()
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ }
+
+ return nil
+}
+
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
// to start delivering packets.
func (n *NIC) attachLinkEndpoint() {
@@ -129,37 +145,6 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- var r *referencedNetworkEndpoint
-
- // Check for a primary endpoint.
- if list, ok := n.primary[protocol]; ok {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- if ref.holdsInsertRef && ref.tryIncRef() {
- r = ref
- break
- }
- }
-
- }
-
- if r == nil {
- return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress
- }
-
- addressWithPrefix := tcpip.AddressWithPrefix{
- Address: r.ep.ID().LocalAddress,
- PrefixLen: r.ep.PrefixLen(),
- }
- r.decRef()
-
- return addressWithPrefix, nil
-}
-
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
@@ -178,7 +163,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN
case header.IPv4Broadcast, header.IPv4Any:
continue
}
- if r.tryIncRef() {
+ if r.isValidForOutgoing() && r.tryIncRef() {
return r
}
}
@@ -197,22 +182,44 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
// getRefEpOrCreateTemp returns the referenced network endpoint for the given
// protocol and address. If none exists a temporary one may be created if
-// requested.
-func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint {
+// we are in promiscuous mode or spoofing.
+func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.RUnlock()
- return ref
+ if ref, ok := n.endpoints[id]; ok {
+ // An endpoint with this id exists, check if it can be used and return it.
+ switch ref.getKind() {
+ case permanentExpired:
+ if !spoofingOrPromiscuous {
+ n.mu.RUnlock()
+ return nil
+ }
+ fallthrough
+ case temporary, permanent:
+ if ref.tryIncRef() {
+ n.mu.RUnlock()
+ return ref
+ }
+ }
}
- // The address was not found, create a temporary one if requested by the
- // caller or if the address is found in the NIC's subnets.
- createTempEP := allowTemp
+ // A usable reference was not found, create a temporary one if requested by
+ // the caller or if the address is found in the NIC's subnets.
+ createTempEP := spoofingOrPromiscuous
if !createTempEP {
- for _, sn := range n.subnets {
+ for _, sn := range n.addressRanges {
+ // Skip the subnet address.
+ if address == sn.ID() {
+ continue
+ }
+ // For now just skip the broadcast address, until we support it.
+ // FIXME(b/137608825): Add support for sending/receiving directed
+ // (subnet) broadcast.
+ if address == sn.Broadcast() {
+ continue
+ }
if sn.Contains(address) {
createTempEP = true
break
@@ -230,34 +237,70 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
- if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
- n.mu.Unlock()
- return ref
+ if ref, ok := n.endpoints[id]; ok {
+ // No need to check the type as we are ok with expired endpoints at this
+ // point.
+ if ref.tryIncRef() {
+ n.mu.Unlock()
+ return ref
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once the
+ // lock is released. Remove it here so we can create a new (temporary) one.
+ // The removal logic waiting for the lock handles this case.
+ n.removeEndpointLocked(ref)
}
+ // Add a new temporary endpoint.
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.mu.Unlock()
return nil
}
-
ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: address,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, peb, true)
-
- if ref != nil {
- ref.holdsInsertRef = false
- }
+ }, peb, temporary)
n.mu.Unlock()
return ref
}
-func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) {
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if ref, ok := n.endpoints[id]; ok {
+ switch ref.getKind() {
+ case permanent:
+ // The NIC already have a permanent endpoint with that address.
+ return nil, tcpip.ErrDuplicateAddress
+ case permanentExpired, temporary:
+ // Promote the endpoint to become permanent.
+ if ref.tryIncRef() {
+ ref.setKind(permanent)
+ return ref, nil
+ }
+ // tryIncRef failing means the endpoint is scheduled to be removed once
+ // the lock is released. Remove it here so we can create a new
+ // (permanent) one. The removal logic waiting for the lock handles this
+ // case.
+ n.removeEndpointLocked(ref)
+ }
+ }
+ return n.addAddressLocked(protocolAddress, peb, permanent)
+}
+
+func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // TODO(b/141022673): Validate IP address before adding them.
+
+ // Sanity check.
+ id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
+ if _, ok := n.endpoints[id]; ok {
+ // Endpoint already exists.
+ return nil, tcpip.ErrDuplicateAddress
+ }
+
netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
@@ -268,22 +311,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
if err != nil {
return nil, err
}
-
- id := *ep.ID()
- if ref, ok := n.endpoints[id]; ok {
- if !replace {
- return nil, tcpip.ErrDuplicateAddress
- }
-
- n.removeEndpointLocked(ref)
- }
-
ref := &referencedNetworkEndpoint{
- refs: 1,
- ep: ep,
- nic: n,
- protocol: protocolAddress.Protocol,
- holdsInsertRef: true,
+ refs: 1,
+ ep: ep,
+ nic: n,
+ protocol: protocolAddress.Protocol,
+ kind: kind,
}
// Set up cache if link address resolution exists for this protocol.
@@ -293,6 +326,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
+ // If we are adding an IPv6 unicast address, join the solicited-node
+ // multicast address.
+ if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) {
+ snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
+ if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
+ return nil, err
+ }
+ }
+
n.endpoints[id] = ref
l, ok := n.primary[protocolAddress.Protocol]
@@ -316,18 +358,26 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
- _, err := n.addAddressLocked(protocolAddress, peb, false)
+ _, err := n.addPermanentAddressLocked(protocolAddress, peb)
n.mu.Unlock()
return err
}
-// Addresses returns the addresses associated with this NIC.
-func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+// AllAddresses returns all addresses (primary and non-primary) associated with
+// this NIC.
+func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
+
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
+ // Don't include expired or tempory endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -339,45 +389,66 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
return addrs
}
-// AddSubnet adds a new subnet to n, so that it starts accepting packets
-// targeted at the given address and network protocol.
-func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+// PrimaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var addrs []tcpip.ProtocolAddress
+ for proto, list := range n.primary {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ // Don't include expired or tempory endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
+
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
+ })
+ }
+ }
+ return addrs
+}
+
+// AddAddressRange adds a range of addresses to n, so that it starts accepting
+// packets targeted at the given addresses and network protocol. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
n.mu.Lock()
- n.subnets = append(n.subnets, subnet)
+ n.addressRanges = append(n.addressRanges, subnet)
n.mu.Unlock()
}
-// RemoveSubnet removes the given subnet from n.
-func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
+// RemoveAddressRange removes the given address range from n.
+func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) {
n.mu.Lock()
// Use the same underlying array.
- tmp := n.subnets[:0]
- for _, sub := range n.subnets {
+ tmp := n.addressRanges[:0]
+ for _, sub := range n.addressRanges {
if sub != subnet {
tmp = append(tmp, sub)
}
}
- n.subnets = tmp
+ n.addressRanges = tmp
n.mu.Unlock()
}
-// ContainsSubnet reports whether this NIC contains the given subnet.
-func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
- for _, s := range n.Subnets() {
- if s == subnet {
- return true
- }
- }
- return false
-}
-
// Subnets returns the Subnets associated with this NIC.
-func (n *NIC) Subnets() []tcpip.Subnet {
+func (n *NIC) AddressRanges() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
- sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+ sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints))
for nid := range n.endpoints {
sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
if err != nil {
@@ -387,19 +458,22 @@ func (n *NIC) Subnets() []tcpip.Subnet {
}
sns = append(sns, sn)
}
- return append(sns, n.subnets...)
+ return append(sns, n.addressRanges...)
}
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
- // Nothing to do if the reference has already been replaced with a
- // different one.
+ // Nothing to do if the reference has already been replaced with a different
+ // one. This happens in the case where 1) this endpoint's ref count hit zero
+ // and was waiting (on the lock) to be removed and 2) the same address was
+ // re-added in the meantime by removing this endpoint from the list and
+ // adding a new one.
if n.endpoints[id] != r {
return
}
- if r.holdsInsertRef {
+ if r.getKind() == permanent {
panic("Reference count dropped to zero before being removed")
}
@@ -418,15 +492,28 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Unlock()
}
-func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
- r := n.endpoints[NetworkEndpointID{addr}]
- if r == nil || !r.holdsInsertRef {
+func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
+ r, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok || r.getKind() != permanent {
return tcpip.ErrBadLocalAddress
}
- r.holdsInsertRef = false
+ r.setKind(permanentExpired)
+ if !r.decRefLocked() {
+ // The endpoint still has references to it.
+ return nil
+ }
- r.decRefLocked()
+ // At this point the endpoint is deleted.
+
+ // If we are removing an IPv6 unicast address, leave the solicited-node
+ // multicast address.
+ if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) {
+ snmc := header.SolicitedNodeAddr(addr)
+ if err := n.leaveGroupLocked(snmc); err != nil {
+ return err
+ }
+ }
return nil
}
@@ -435,7 +522,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error {
func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
- return n.removeAddressLocked(addr)
+ return n.removePermanentAddressLocked(addr)
}
// joinGroup adds a new endpoint for the given multicast address, if none
@@ -444,6 +531,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
n.mu.Lock()
defer n.mu.Unlock()
+ return n.joinGroupLocked(protocol, addr)
+}
+
+// joinGroupLocked adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count. n MUST be locked before
+// joinGroupLocked is called.
+func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -451,13 +545,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
if !ok {
return tcpip.ErrUnknownProtocol
}
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
+ if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{
Protocol: protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: netProto.DefaultPrefixLen(),
},
- }, NeverPrimaryEndpoint, false); err != nil {
+ }, NeverPrimaryEndpoint); err != nil {
return err
}
}
@@ -471,6 +565,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
+ return n.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked decrements the count for the given multicast address, and
+// when it reaches zero removes the endpoint for this address. n MUST be locked
+// before leaveGroupLocked is called.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
switch joins {
@@ -479,7 +580,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
case 1:
// This is the last one, clean up.
- if err := n.removeAddressLocked(addr); err != nil {
+ if err := n.removePermanentAddressLocked(addr); err != nil {
return err
}
}
@@ -487,6 +588,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
return nil
}
+func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) {
+ r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
+ r.RemoteLinkAddress = remotelinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ ref.decRef()
+}
+
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the physical interface.
@@ -514,6 +622,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
src, dst := netProto.ParseAddresses(vv.First())
+ n.stack.AddLinkAddress(n.id, src, remote)
+
// If the packet is destined to the IPv4 Broadcast address, then make a
// route to each IPv4 network endpoint and let each endpoint handle the
// packet.
@@ -521,11 +631,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
// n.endpoints is mutex protected so acquire lock.
n.mu.RLock()
for _, ref := range n.endpoints {
- if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() {
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
}
}
n.mu.RUnlock()
@@ -533,10 +640,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
}
if ref := n.getRef(protocol, dst); ref != nil {
- r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */)
- r.RemoteLinkAddress = remote
- ref.ep.HandlePacket(&r, vv)
- ref.decRef()
+ handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv)
return
}
@@ -559,8 +663,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
n := r.ref.nic
n.mu.RLock()
ref, ok := n.endpoints[NetworkEndpointID{dst}]
+ ok = ok && ref.isValidForOutgoing() && ref.tryIncRef()
n.mu.RUnlock()
- if ok && ref.tryIncRef() {
+ if ok {
r.RemoteAddress = src
// TODO(b/123449044): Update the source NIC as well.
ref.ep.HandlePacket(&r, vv)
@@ -599,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
- }
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
@@ -615,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
- return
- }
if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
@@ -631,7 +731,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
- if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
+ if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
@@ -659,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
- return
- }
- if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
return
}
}
@@ -672,9 +769,38 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+// Stack returns the instance of the Stack that owns this NIC.
+func (n *NIC) Stack() *Stack {
+ return n.stack
+}
+
+type networkEndpointKind int32
+
+const (
+ // A permanent endpoint is created by adding a permanent address (vs. a
+ // temporary one) to the NIC. Its reference count is biased by 1 to avoid
+ // removal when no route holds a reference to it. It is removed by explicitly
+ // removing the permanent address from the NIC.
+ permanent networkEndpointKind = iota
+
+ // An expired permanent endoint is a permanent endoint that had its address
+ // removed from the NIC, and it is waiting to be removed once no more routes
+ // hold a reference to it. This is achieved by decreasing its reference count
+ // by 1. If its address is re-added before the endpoint is removed, its type
+ // changes back to permanent and its reference count increases by 1 again.
+ permanentExpired
+
+ // A temporary endpoint is created for spoofing outgoing packets, or when in
+ // promiscuous mode and accepting incoming packets that don't match any
+ // permanent endpoint. Its reference count is not biased by 1 and the
+ // endpoint is removed immediately when no more route holds a reference to
+ // it. A temporary endpoint can be promoted to permanent if its address
+ // is added permanently.
+ temporary
+)
+
type referencedNetworkEndpoint struct {
ilist.Entry
- refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
@@ -683,11 +809,34 @@ type referencedNetworkEndpoint struct {
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
- // holdsInsertRef is protected by the NIC's mutex. It indicates whether
- // the reference count is biased by 1 due to the insertion of the
- // endpoint. It is reset to false when RemoveAddress is called on the
- // NIC.
- holdsInsertRef bool
+ // refs is counting references held for this endpoint. When refs hits zero it
+ // triggers the automatic removal of the endpoint from the NIC.
+ refs int32
+
+ // networkEndpointKind must only be accessed using {get,set}Kind().
+ kind networkEndpointKind
+}
+
+func (r *referencedNetworkEndpoint) getKind() networkEndpointKind {
+ return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind)))
+}
+
+func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) {
+ atomic.StoreInt32((*int32)(&r.kind), int32(kind))
+}
+
+// isValidForOutgoing returns true if the endpoint can be used to send out a
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in spoofing mode.
+func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
+ return r.getKind() != permanentExpired || r.nic.spoofing
+}
+
+// isValidForIncoming returns true if the endpoint can accept an incoming
+// packet. It requires the endpoint to not be marked expired (i.e., its address
+// has been removed), or the NIC to be in promiscuous mode.
+func (r *referencedNetworkEndpoint) isValidForIncoming() bool {
+ return r.getKind() != permanentExpired || r.nic.promiscuous
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
@@ -699,11 +848,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked.
-func (r *referencedNetworkEndpoint) decRefLocked() {
+// locked. Returns true if the endpoint was removed.
+func (r *referencedNetworkEndpoint) decRefLocked() bool {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpointLocked(r)
+ return true
}
+
+ return false
}
// incRef increments the ref count. It must only be called when the caller is
@@ -728,3 +880,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool {
}
}
}
+
+// stack returns the Stack instance that owns the underlying endpoint.
+func (r *referencedNetworkEndpoint) stack() *Stack {
+ return r.nic.stack
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 2037eef9f..80101d4bb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -15,8 +15,6 @@
package stack
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -109,7 +107,7 @@ type TransportProtocol interface {
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
- HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
@@ -297,6 +295,15 @@ type LinkEndpoint interface {
// IsAttached returns whether a NetworkDispatcher is attached to the
// endpoint.
IsAttached() bool
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // For now, requesting that an endpoint's worker goroutine(s) stop is
+ // implementation specific.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -359,14 +366,6 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// TransportProtocolFactory functions are used by the stack to instantiate
-// transport protocols.
-type TransportProtocolFactory func() TransportProtocol
-
-// NetworkProtocolFactory provides methods to be used by the stack to
-// instantiate network protocols.
-type NetworkProtocolFactory func() NetworkProtocol
-
// UnassociatedEndpointFactory produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can be used
// to write arbitrary packets that include the IP header.
@@ -374,60 +373,6 @@ type UnassociatedEndpointFactory interface {
NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
-var (
- transportProtocols = make(map[string]TransportProtocolFactory)
- networkProtocols = make(map[string]NetworkProtocolFactory)
-
- unassociatedFactory UnassociatedEndpointFactory
-
- linkEPMu sync.RWMutex
- nextLinkEndpointID tcpip.LinkEndpointID = 1
- linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
-)
-
-// RegisterTransportProtocolFactory registers a new transport protocol factory
-// with the stack so that it becomes available to users of the stack. This
-// function is intended to be called by init() functions of the protocols.
-func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
- transportProtocols[name] = p
-}
-
-// RegisterNetworkProtocolFactory registers a new network protocol factory with
-// the stack so that it becomes available to users of the stack. This function
-// is intended to be called by init() functions of the protocols.
-func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
- networkProtocols[name] = p
-}
-
-// RegisterUnassociatedFactory registers a factory to produce endpoints not
-// associated with any particular transport protocol. This function is intended
-// to be called by init() functions of the protocols.
-func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
- unassociatedFactory = f
-}
-
-// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
-// ID that can be used to refer to it.
-func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
- linkEPMu.Lock()
- defer linkEPMu.Unlock()
-
- v := nextLinkEndpointID
- nextLinkEndpointID++
-
- linkEndpoints[v] = linkEP
-
- return v
-}
-
-// FindLinkEndpoint finds the link endpoint associated with the given ID.
-func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
- linkEPMu.RLock()
- defer linkEPMu.RUnlock()
-
- return linkEndpoints[id]
-}
-
// GSOType is the type of GSO segments.
//
// +stateify savable
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 391ab4344..5c8b7977a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
func (r *Route) IsResolutionRequired() bool {
- return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+ return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop)
if err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
@@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
// WriteHeaderIncludedPacket writes a packet already containing a network
// header through the given route.
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+ if !r.ref.isValidForOutgoing() {
+ return tcpip.ErrInvalidEndpointState
+ }
+
if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return err
@@ -209,3 +217,8 @@ func (r *Route) MakeLoopedRoute() Route {
}
return l
}
+
+// Stack returns the instance of the Stack that owns this route.
+func (r *Route) Stack() *Stack {
+ return r.ref.stack()
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index d69162ba1..90c2cf1be 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -17,17 +17,15 @@
//
// For consumers, the only function of interest is New(), everything else is
// provided by the tcpip/public package.
-//
-// For protocol implementers, RegisterTransportProtocolFactory() and
-// RegisterNetworkProtocolFactory() are used to register protocol factories with
-// the stack, which will then be used to instantiate protocol objects when
-// consumers interact with the stack.
package stack
import (
+ "encoding/binary"
"sync"
"time"
+ "golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -350,6 +348,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // unassociatedFactory creates unassociated endpoints. If nil, raw
+ // endpoints are disabled. It is set during Stack creation and is
+ // immutable.
unassociatedFactory UnassociatedEndpointFactory
demux *transportDemuxer
@@ -358,10 +359,6 @@ type Stack struct {
linkAddrCache *linkAddrCache
- // raw indicates whether raw sockets may be created. It is set during
- // Stack creation and is immutable.
- raw bool
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
forwarding bool
@@ -389,10 +386,26 @@ type Stack struct {
// resumableEndpoints is a list of endpoints that need to be resumed if the
// stack is being restored.
resumableEndpoints []ResumableEndpoint
+
+ // icmpRateLimiter is a global rate limiter for all ICMP messages generated
+ // by the stack.
+ icmpRateLimiter *ICMPRateLimiter
+
+ // portSeed is a one-time random value initialized at stack startup
+ // and is used to seed the TCP port picking on active connections
+ //
+ // TODO(gvisor.dev/issues/940): S/R this field.
+ portSeed uint32
}
// Options contains optional Stack configuration.
type Options struct {
+ // NetworkProtocols lists the network protocols to enable.
+ NetworkProtocols []NetworkProtocol
+
+ // TransportProtocols lists the transport protocols to enable.
+ TransportProtocols []TransportProtocol
+
// Clock is an optional clock source used for timestampping packets.
//
// If no Clock is specified, the clock source will be time.Now.
@@ -406,8 +419,9 @@ type Options struct {
// stack (false).
HandleLocal bool
- // Raw indicates whether raw sockets may be created.
- Raw bool
+ // UnassociatedFactory produces unassociated endpoints raw endpoints.
+ // Raw endpoints are enabled only if this is non-nil.
+ UnassociatedFactory UnassociatedEndpointFactory
}
// New allocates a new networking stack with only the requested networking and
@@ -417,7 +431,7 @@ type Options struct {
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
-func New(network []string, transport []string, opts Options) *Stack {
+func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
@@ -433,16 +447,12 @@ func New(network []string, transport []string, opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- raw: opts.Raw,
+ icmpRateLimiter: NewICMPRateLimiter(),
+ portSeed: generateRandUint32(),
}
// Add specified network protocols.
- for _, name := range network {
- netProtoFactory, ok := networkProtocols[name]
- if !ok {
- continue
- }
- netProto := netProtoFactory()
+ for _, netProto := range opts.NetworkProtocols {
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -450,18 +460,14 @@ func New(network []string, transport []string, opts Options) *Stack {
}
// Add specified transport protocols.
- for _, name := range transport {
- transProtoFactory, ok := transportProtocols[name]
- if !ok {
- continue
- }
- transProto := transProtoFactory()
+ for _, transProto := range opts.TransportProtocols {
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
}
- s.unassociatedFactory = unassociatedFactory
+ // Add the factory for unassociated endpoints, if present.
+ s.unassociatedFactory = opts.UnassociatedFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -596,7 +602,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if !s.raw {
+ if s.unassociatedFactory == nil {
return nil, tcpip.ErrNotPermitted
}
@@ -614,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
- ep := FindLinkEndpoint(linkEP)
- if ep == nil {
- return tcpip.ErrBadLinkEndpoint
- }
-
+func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -632,40 +633,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
s.nics[id] = n
if enabled {
- n.attachLinkEndpoint()
+ return n.enable()
}
return nil
}
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
-func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, true, false)
+func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, true, false)
}
// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, false)
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, false)
}
// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, true)
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leave it disable. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, false, false)
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, false, false)
}
// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, false, false)
+func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, false, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -679,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
- nic.attachLinkEndpoint()
-
- return nil
+ return nic.enable()
}
// CheckNIC checks if a NIC is usable.
@@ -696,14 +695,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool {
}
// NICSubnets returns a map of NICIDs to their associated subnets.
-func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
+func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
defer s.mu.RUnlock()
nics := map[tcpip.NICID][]tcpip.Subnet{}
for id, nic := range s.nics {
- nics[id] = append(nics[id], nic.Subnets()...)
+ nics[id] = append(nics[id], nic.AddressRanges()...)
}
return nics
}
@@ -739,7 +738,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.Addresses(),
+ ProtocolAddresses: nic.PrimaryAddresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
@@ -804,71 +803,79 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc
return nic.AddAddress(protocolAddress, peb)
}
-// AddSubnet adds a subnet range to the specified NIC.
-func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+// AddAddressRange adds a range of addresses to the specified NIC. The range is
+// given by a subnet address, and all addresses contained in the subnet are
+// used except for the subnet address itself and the subnet's broadcast
+// address.
+func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- nic.AddSubnet(protocol, subnet)
+ nic.AddAddressRange(protocol, subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
-// RemoveSubnet removes the subnet range from the specified NIC.
-func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
+// RemoveAddressRange removes the range of addresses from the specified NIC.
+func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- nic.RemoveSubnet(subnet)
+ nic.RemoveAddressRange(subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
-// ContainsSubnet reports whether the specified NIC contains the specified
-// subnet.
-func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
- return nic.ContainsSubnet(subnet), nil
+ return nic.RemoveAddress(addr)
}
- return false, tcpip.ErrUnknownNICID
+ return tcpip.ErrUnknownNICID
}
-// RemoveAddress removes an existing network-layer address from the specified
-// NIC.
-func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+// AllAddresses returns a map of NICIDs to their protocol addresses (primary
+// and non-primary).
+func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.RemoveAddress(addr)
+ nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
+ for id, nic := range s.nics {
+ nics[id] = nic.AllAddresses()
}
-
- return tcpip.ErrUnknownNICID
+ return nics
}
-// GetMainNICAddress returns the first primary address (and the subnet that
-// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
-// address if no primary addresses exist. Returns an error if the NIC doesn't
-// exist or has no endpoints.
+// GetMainNICAddress returns the first primary address and prefix for the given
+// NIC and protocol. Returns an error if the NIC doesn't exist and an empty
+// value if the NIC doesn't have a primary address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.getMainNICAddress(protocol)
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ for _, a := range nic.PrimaryAddresses() {
+ if a.Protocol == protocol {
+ return a.AddressWithPrefix, nil
+ }
+ }
+ return tcpip.AddressWithPrefix{}, nil
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
@@ -1035,73 +1042,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- }
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerRawEndpoint(netProto, transProto, ep)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProto, transProto, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
- }
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
// RegisterRestoredEndpoint records e as an endpoint that has been restored on
@@ -1215,3 +1176,49 @@ func (s *Stack) IPTables() iptables.IPTables {
func (s *Stack) SetIPTables(ipt iptables.IPTables) {
s.tables = ipt
}
+
+// ICMPLimit returns the maximum number of ICMP messages that can be sent
+// in one second.
+func (s *Stack) ICMPLimit() rate.Limit {
+ return s.icmpRateLimiter.Limit()
+}
+
+// SetICMPLimit sets the maximum number of ICMP messages that be sent
+// in one second.
+func (s *Stack) SetICMPLimit(newLimit rate.Limit) {
+ s.icmpRateLimiter.SetLimit(newLimit)
+}
+
+// ICMPBurst returns the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) ICMPBurst() int {
+ return s.icmpRateLimiter.Burst()
+}
+
+// SetICMPBurst sets the maximum number of ICMP messages that can be sent
+// in a single burst.
+func (s *Stack) SetICMPBurst(burst int) {
+ s.icmpRateLimiter.SetBurst(burst)
+}
+
+// AllowICMPMessage returns true if we the rate limiter allows at least one
+// ICMP message to be sent at this instant.
+func (s *Stack) AllowICMPMessage() bool {
+ return s.icmpRateLimiter.Allow()
+}
+
+// PortSeed returns a 32 bit value that can be used as a seed value for port
+// picking.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) PortSeed() uint32 {
+ return s.portSeed
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 137c6183e..d2dede8a9 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct {
prefixLen int
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- linkEP stack.LinkEndpoint
+ ep stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
@@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
+ return f.ep.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto
}
func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.linkEP.Capabilities()
+ return f.ep.Capabilities()
}
func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
@@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return nil
}
- return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
+ return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
@@ -181,18 +181,22 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int {
return fakeDefaultPrefixLen
}
+func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
+ return f.packetCount[int(intfAddr)%len(f.packetCount)]
+}
+
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
- linkEP: linkEP,
+ ep: ep,
}, nil
}
@@ -218,12 +222,18 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeNetFactory() stack.NetworkProtocol {
+ return &fakeNetworkProtocol{}
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -241,7 +251,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -251,7 +261,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -261,7 +271,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -270,7 +280,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber-1, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -280,7 +290,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -289,16 +299,75 @@ func TestNetworkReceive(t *testing.T) {
}
}
-func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) {
+func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error {
r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
- t.Fatal("FindRoute failed:", err)
+ return err
}
defer r.Release()
+ return send(r, payload)
+}
+func send(r stack.Route, payload buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
- if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil {
- t.Error("WritePacket failed:", err)
+ return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123)
+}
+
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := sendTo(s, addr, payload); err != nil {
+ t.Error("sendTo failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("sendTo packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
+ t.Helper()
+ ep.Drain()
+ if err := send(r, payload); err != nil {
+ t.Error("send failed:", err)
+ }
+ if got, want := ep.Drain(), 1; got != want {
+ t.Errorf("send packet count: got = %d, want %d", got, want)
+ }
+}
+
+func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := send(r, payload); gotErr != wantErr {
+ t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+ t.Helper()
+ if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
+ t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
+ }
+}
+
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte) + 1
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
+ t.Helper()
+ // testRecvInternal injects one packet, and we do NOT expect to receive it.
+ want := fakeNet.PacketCount(localAddrByte)
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
+}
+
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
+ t.Helper()
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
+ if got := fakeNet.PacketCount(localAddrByte); got != want {
+ t.Errorf("receive packet count: got = %d, want %d", got, want)
}
}
@@ -306,9 +375,11 @@ func TestNetworkSend(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and one
// address: 1. The route table sends all packets through the only
// existing nic.
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
}
@@ -325,20 +396,19 @@ func TestNetworkSend(t *testing.T) {
}
// Make sure that the link-layer endpoint received the outbound packet.
- sendTo(t, s, "\x03", nil)
- if c := linkEP.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x03", ep, nil)
}
func TestNetworkSendMultiRoute(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -350,8 +420,8 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -382,18 +452,10 @@ func TestNetworkSendMultiRoute(t *testing.T) {
}
// Send a packet to an odd destination.
- sendTo(t, s, "\x05", nil)
-
- if c := linkEP1.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x05", ep1, nil)
// Send a packet to an even destination.
- sendTo(t, s, "\x06", nil)
-
- if c := linkEP2.Drain(); c != 1 {
- t.Errorf("packetCount = %d, want %d", c, 1)
- }
+ testSendTo(t, s, "\x06", ep2, nil)
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
@@ -424,10 +486,12 @@ func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id1, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -439,8 +503,8 @@ func TestRoutes(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -498,58 +562,71 @@ func TestRoutes(t *testing.T) {
}
func TestAddressRemoval(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
- // Remove the address, then check that packet doesn't get delivered
- // anymore.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
}
-func TestDelayedRemovalDueToRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+func TestAddressRemovalWithRouteHeld(t *testing.T) {
+ const localAddrByte byte = 0x01
+ localAddr := tcpip.Address([]byte{localAddrByte})
+ remoteAddr := tcpip.Address("\x02")
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
}
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
- if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
-
{
subnet, err := tcpip.NewSubnet("\x00", "\x00")
if err != nil {
@@ -558,58 +635,239 @@ func TestDelayedRemovalDueToRoute(t *testing.T) {
s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
}
- fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
-
- buf := buffer.NewView(30)
-
- // Write a packet, and check that it gets delivered.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
-
- // Get a route, check that packet is still deliverable.
- r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 2 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
- }
+ // Send and receive packets, and verify they are received.
+ buf[0] = localAddrByte
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
- // Remove the address, then check that packet is still deliverable
- // because the route is keeping the address alive.
- if err := s.RemoveAddress(1, "\x01"); err != nil {
+ // Remove the address, then check that send/receive doesn't work anymore.
+ if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
- if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+ if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress)
}
+}
+
+func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) {
+ t.Helper()
+ info, ok := s.NICInfo()[nicid]
+ if !ok {
+ t.Fatalf("NICInfo() failed to find nicid=%d", nicid)
+ }
+ if len(addr) == 0 {
+ // No address given, verify that there is no address assigned to the NIC.
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) {
+ t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{}))
+ }
+ }
+ return
+ }
+ // Address given, verify the address is assigned to the NIC and no other
+ // address is.
+ found := false
+ for _, a := range info.ProtocolAddresses {
+ if a.Protocol == fakeNetNumber {
+ if a.AddressWithPrefix.Address == addr {
+ found = true
+ } else {
+ t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr)
+ }
+ }
+ }
+ if !found {
+ t.Errorf("verify address: couldn't find %s on the NIC", addr)
+ }
+}
+
+func TestEndpointExpiration(t *testing.T) {
+ const (
+ localAddrByte byte = 0x01
+ remoteAddr tcpip.Address = "\x03"
+ noAddr tcpip.Address = ""
+ nicid tcpip.NICID = 1
+ )
+ localAddr := tcpip.Address([]byte{localAddrByte})
+
+ for _, promiscuous := range []bool{true, false} {
+ for _, spoofing := range []bool{true, false} {
+ t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- // Release the route, then check that packet is not deliverable anymore.
- r.Release()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 3 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+ buf := buffer.NewView(30)
+ buf[0] = localAddrByte
+
+ if promiscuous {
+ if err := s.SetPromiscuousMode(nicid, true); err != nil {
+ t.Fatal("SetPromiscuousMode failed:", err)
+ }
+ }
+
+ if spoofing {
+ if err := s.SetSpoofing(nicid, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ }
+
+ // 1. No Address yet, send should only work for spoofing, receive for
+ // promiscuous mode.
+ //-----------------------
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 2. Add Address, everything should work.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 3. Remove the address, send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 4. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 5. Take a reference to the endpoint by getting a route. Verify that
+ // we can still send/receive, including sending using the route.
+ //-----------------------
+ r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // 6. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+
+ // 7. Add Address back, everything should work again.
+ //-----------------------
+ if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil {
+ t.Fatal("AddAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // 8. Remove the route, sendTo/recv should still work.
+ //-----------------------
+ r.Release()
+ verifyAddress(t, s, nicid, localAddr)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+
+ // 9. Remove the address. Send should only work for spoofing, receive
+ // for promiscuous mode.
+ //-----------------------
+ if err := s.RemoveAddress(nicid, localAddr); err != nil {
+ t.Fatal("RemoveAddress failed:", err)
+ }
+ verifyAddress(t, s, nicid, noAddr)
+ if promiscuous {
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ } else {
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ }
+ if spoofing {
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
+ } else {
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
+ }
+ })
+ }
}
}
func TestPromiscuousMode(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -627,22 +885,15 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
- fakeNet.packetCount[1] = 0
- buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
- }
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
@@ -655,25 +906,24 @@ func TestPromiscuousMode(t *testing.T) {
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
-
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
- }
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
-func TestAddressSpoofing(t *testing.T) {
- srcAddr := tcpip.Address("\x01")
- dstAddr := tcpip.Address("\x02")
+func TestSpoofingWithAddress(t *testing.T) {
+ localAddr := tcpip.Address("\x01")
+ nonExistentLocalAddr := tcpip.Address("\x02")
+ dstAddr := tcpip.Address("\x03")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
- if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
+ if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil {
t.Fatal("AddAddress failed:", err)
}
@@ -687,7 +937,7 @@ func TestAddressSpoofing(t *testing.T) {
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
- r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
@@ -697,23 +947,92 @@ func TestAddressSpoofing(t *testing.T) {
if err := s.SetSpoofing(1, true); err != nil {
t.Fatal("SetSpoofing failed:", err)
}
- r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+ testSendTo(t, s, dstAddr, ep, nil)
+ testSend(t, r, ep, nil)
+
+ // FindRoute should also work with a local address that exists on the NIC.
+ r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- if r.LocalAddress != srcAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+ if r.LocalAddress != localAddr {
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
+ // Sending a packet using the route works.
+ testSend(t, r, ep, nil)
+}
+
+func TestSpoofingNoAddress(t *testing.T) {
+ nonExistentLocalAddr := tcpip.Address("\x01")
+ dstAddr := tcpip.Address("\x02")
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatal("CreateNIC failed:", err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x00", "\x00")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}})
+ }
+
+ // With address spoofing disabled, FindRoute does not permit an address
+ // that was not added to the NIC to be used as the source.
+ r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err == nil {
+ t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+ }
+ // Sending a packet fails.
+ testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
+
+ // With address spoofing enabled, FindRoute permits any address to be used
+ // as the source.
+ if err := s.SetSpoofing(1, true); err != nil {
+ t.Fatal("SetSpoofing failed:", err)
+ }
+ r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Fatal("FindRoute failed:", err)
+ }
+ if r.LocalAddress != nonExistentLocalAddr {
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
+ }
+ if r.RemoteAddress != dstAddr {
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
+ }
+ // Sending a packet works.
+ // FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
+ // testSendTo(t, s, remoteAddr, ep, nil)
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{})
@@ -781,10 +1100,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
{"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
} {
t.Run(tc.name, func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -835,12 +1156,14 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
-// Set the subnet, then check that packet is delivered.
-func TestSubnetAcceptsMatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+// Add a range of addresses, then check that a packet is delivered.
+func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -856,29 +1179,59 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 1 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+}
+
+func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
+ t.Helper()
+
+ // Loop over all addresses and check them.
+ numOfAddresses := 1 << uint(8-subnet.Prefix())
+ if numOfAddresses < 1 || numOfAddresses > 255 {
+ t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
+ }
+
+ addrBytes := []byte(subnet.ID())
+ for i := 0; i < numOfAddresses; i++ {
+ addr := tcpip.Address(addrBytes)
+ wantNicID := nicID
+ // The subnet and broadcast addresses are skipped.
+ if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() {
+ wantNicID = 0
+ }
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID)
+ }
+ addrBytes[0]++
+ }
+
+ // Trying the next address should always fail since it is outside the range.
+ if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 {
+ t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0)
}
}
-// Set the subnet, then check that CheckLocalAddress returns the correct NIC.
+// Set a range of addresses, then remove it again, and check at each step that
+// CheckLocalAddress returns the correct NIC for each address or zero if not
+// existent.
func TestCheckLocalAddressForSubnet(t *testing.T) {
const nicID tcpip.NICID = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -891,39 +1244,34 @@ func TestCheckLocalAddressForSubnet(t *testing.T) {
}
subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0"))
-
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
- }
- // Loop over all subnet addresses and check them.
- numOfAddresses := 1 << uint(8-subnet.Prefix())
- if numOfAddresses < 1 || numOfAddresses > 255 {
- t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet)
- }
- addr := []byte(subnet.ID())
- for i := 0; i < numOfAddresses; i++ {
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID)
- }
- addr[0]++
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
+
+ if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- // Trying the next address should fail since it is outside the subnet range.
- if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 {
- t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0)
+ testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */)
+
+ if err := s.RemoveAddressRange(nicID, subnet); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
}
+
+ testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */)
}
-// Set destination outside the subnet, then check it doesn't get delivered.
-func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+// Set a range of addresses, then send a packet to a destination outside the
+// range and then check it doesn't get delivered.
+func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -939,23 +1287,23 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
- buf[0] = 1
- fakeNet.packetCount[1] = 0
+ const localAddrByte byte = 0x01
+ buf[0] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
- }
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
- if fakeNet.packetCount[1] != 0 {
- t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
+ if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
func TestNetworkOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{},
+ })
// Try an unsupported network protocol.
if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -994,44 +1342,53 @@ func TestNetworkOptions(t *testing.T) {
}
}
-func TestSubnetAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool {
+ ranges, ok := s.NICAddressRanges()[id]
+ if !ok {
+ return false
+ }
+ for _, r := range ranges {
+ if r == addrRange {
+ return true
+ }
+ }
+ return false
+}
+
+func TestAddresRangeAddRemove(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
addr := tcpip.Address("\x01\x01\x01\x01")
mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
- subnet, err := tcpip.NewSubnet(addr, mask)
+ addrRange, err := tcpip.NewSubnet(addr, mask)
if err != nil {
t.Fatal("NewSubnet failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
- if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
- t.Fatal("AddSubnet failed:", err)
+ if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil {
+ t.Fatal("AddAddressRange failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if !contained {
- t.Fatal("got s.ContainsSubnet(...) = false, want = true")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
- if err := s.RemoveSubnet(1, subnet); err != nil {
- t.Fatal("RemoveSubnet failed:", err)
+ if err := s.RemoveAddressRange(1, addrRange); err != nil {
+ t.Fatal("RemoveAddressRange failed:", err)
}
- if contained, err := s.ContainsSubnet(1, subnet); err != nil {
- t.Fatal("ContainsSubnet failed:", err)
- } else if contained {
- t.Fatal("got s.ContainsSubnet(...) = true, want = false")
+ if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want {
+ t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want)
}
}
@@ -1042,9 +1399,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
// Insert <canBe> primary and <never> never-primary addresses.
@@ -1082,20 +1441,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
if len(primaryAddrAdded) == 0 {
- // No primary addresses present, expect an error.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress)
+ // No primary addresses present.
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
}
} else {
- // At least one primary address was added, expect a valid
- // address and prefixLen.
- gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok {
- t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded)
+ // At least one primary address was added, verify the returned
+ // address is in the list of primary addresses we added.
+ if _, ok := primaryAddrAdded[gotAddr]; !ok {
+ t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
}
}
})
@@ -1107,9 +1466,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1134,19 +1495,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
t.Fatal("GetMainNICAddress failed:", err)
- } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix)
+ }
+ if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
+ t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- // Check that we get an error after removal.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress)
+ // Check that we get no address after removal.
+ gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
})
}
@@ -1161,8 +1528,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address {
}
func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
+ t.Helper()
+
if len(gotAddresses) != len(expectedAddresses) {
- t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses))
+ t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
}
sort.Slice(gotAddresses, func(i, j int) bool {
@@ -1182,9 +1551,11 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1201,15 +1572,17 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1233,15 +1606,17 @@ func TestAddProtocolAddress(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1262,15 +1637,17 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1297,15 +1674,17 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestNICStats(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
+ t.Fatal("CreateNIC failed: ", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress failed:", err)
@@ -1321,7 +1700,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -1332,10 +1711,12 @@ func TestNICStats(t *testing.T) {
payload := buffer.NewView(10)
// Write a packet out via the address for NIC 1
- sendTo(t, s, "\x01", payload)
- want := uint64(linkEP1.Drain())
+ if err := sendTo(s, "\x01", payload); err != nil {
+ t.Fatal("sendTo failed: ", err)
+ }
+ want := uint64(ep1.Drain())
if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
+ t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
@@ -1346,19 +1727,21 @@ func TestNICStats(t *testing.T) {
func TestNICForwarding(t *testing.T) {
// Create a stack with the fake network protocol, two NICs, each with
// an address.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
s.SetForwarding(true)
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress #1 failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -1377,10 +1760,10 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to address 3.
buf := buffer.NewView(30)
buf[0] = 3
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
select {
- case <-linkEP2.C:
+ case <-ep2.C:
default:
t.Fatal("Packet not forwarded")
}
@@ -1394,9 +1777,3 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
-
-func init() {
- stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
- })
-}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cf8a6d129..8c768c299 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -35,25 +35,109 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]TransportEndpoint
+ endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
+type endpointsByNic struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ }
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNic.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // broadcast like we are doing with handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+}
+
+// registerEndpoint returns true if it succeeds. It fails and returns
+// false if ep already has an element with the same key.
+func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+
+ if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
+ // There was already a bind.
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ }
+
+ // This is a new binding.
+ multiPortEp := &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.reuse = reusePort
+ epsByNic.endpoints[bindToDevice] = multiPortEp
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+}
+
+// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
+func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t) {
+ delete(epsByNic.endpoints, bindToDevice)
+ }
+ return len(epsByNic.endpoints) == 0
+}
+
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- e, ok := eps.endpoints[id]
+ epsByNic, ok := eps.endpoints[id]
if !ok {
return
}
- if multiPortEp, ok := e.(*multiPortEndpoint); ok {
- if !multiPortEp.unregisterEndpoint(ep) {
- return
- }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
}
delete(eps.endpoints, id)
}
@@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
}
}
@@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
return err
}
}
@@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
-// the same pair of address and port.
+// the same pair of address and port. endpointsArr always has at least one
+// element.
type multiPortEndpoint struct {
mu sync.RWMutex
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
- // seed is a random secret for a jenkins hash.
- seed uint32
+ // reuse indicates if more than one endpoint is allowed.
+ reuse bool
}
// reciprocalScale scales a value into range [0, n).
@@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpointsArr) == 1 {
+ return mpep.endpointsArr[0]
+ }
payload := []byte{
byte(id.LocalPort),
@@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
byte(id.RemotePort >> 8),
}
- h := jenkins.Sum32(ep.seed)
+ h := jenkins.Sum32(seed)
h.Write(payload)
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
- return ep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
+ return mpep.endpointsArr[idx]
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast or multicast datagram, deliver the datagram to all
- // endpoints managed by ep.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
- }
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.mu.RLock()
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy except for
+ // the final one.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
}
- } else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
-}
-
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list might be empty already.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- // A new endpoint is added into endpointsArr and its index there is
- // saved in endpointsMap. This will allows to remove endpoint from
- // the array fast.
+ if len(ep.endpointsArr) > 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if !ep.reuse || !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ // A new endpoint is added into endpointsArr and its index there is saved in
+ // endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+ return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
@@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
return true
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
+ // TODO(eyalsoha): Why?
reusePort = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return nil
+ return tcpip.ErrUnknownProtocol
}
eps.mu.Lock()
defer eps.mu.Unlock()
- var multiPortEp *multiPortEndpoint
- if _, ok := eps.endpoints[id]; ok {
- if !reusePort {
- return tcpip.ErrPortInUse
- }
- multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
- if !ok {
- return tcpip.ErrPortInUse
- }
+ if epsByNic, ok := eps.endpoints[id]; ok {
+ // There was already a binding.
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
- if reusePort {
- if multiPortEp == nil {
- multiPortEp = &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.seed = rand.Uint32()
- eps.endpoints[id] = multiPortEp
- }
-
- multiPortEp.singleRegisterEndpoint(ep)
-
- return nil
+ // This is a new binding.
+ epsByNic := &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
}
- eps.endpoints[id] = ep
+ eps.endpoints[id] = epsByNic
- return nil
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep)
+ eps.unregisterEndpoint(id, ep, bindToDevice)
}
}
}
@@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a broadcast, then find all matching transport endpoints.
// Otherwise, try to find a single matching transport endpoint.
- destEps := make([]TransportEndpoint, 0, 1)
+ destEps := make([]*endpointsByNic, 0, 1)
eps.mu.RLock()
if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
@@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.handlePacket(r, id, vv)
}
return true
@@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
}
// Deliver the packet.
- ep.HandleControlPacket(id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, vv)
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
return ep
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..210233dc0
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testPort = 4096
+)
+
+type testContext struct {
+ t *testing.T
+ linkEPs map[string]*channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v6only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// newDualTestContextMultiNic creates the testing context and also linkEpNames
+// named NICs.
+func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ linkEPs := make(map[string]*channel.Endpoint)
+ for i, linkEpName := range linkEpNames {
+ channelEP := channel.New(256, mtu, "")
+ nicid := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ linkEPs[linkEpName] = channelEP
+
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %v", err)
+ }
+
+ if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %v", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEPs: linkEPs,
+ }
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestReuseBindToDevice injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse int
+ bindToDevice string
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will received the inject packets.
+ endpoints []endpointSockopts
+ // wantedDistribution is the wanted ratio of packets received on each
+ // endpoint for each NIC on which packets are injected.
+ wantedDistributions map[string][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed evenly.
+ "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{0, "dev0"},
+ endpointSockopts{0, "dev1"},
+ endpointSockopts{0, "dev2"},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 go only to the endpoint bound to dev0.
+ "dev0": []float64{1, 0, 0},
+ // Injected packets on dev1 go only to the endpoint bound to dev1.
+ "dev1": []float64{0, 1, 0},
+ // Injected packets on dev2 go only to the endpoint bound to dev2.
+ "dev2": []float64{0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed among endpoints bound to
+ // dev0.
+ "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on dev1 get distributed among endpoints bound to
+ // dev1 or unbound.
+ "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on dev999 go only to the unbound.
+ "dev999": []float64{0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ for device, wantedDistribution := range test.wantedDistributions {
+ t.Run(device, func(t *testing.T) {
+ var devices []string
+ for d := range test.wantedDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Try to receive the data.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
+ if err := ep.SetSockOpt(reusePortOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload,
+ &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort},
+ device)
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that a packet distribution is as expected.
+ for ep, i := range eps {
+ wantedRatio := wantedDistribution[i]
+ wantedRecv := wantedRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // The deviation is less than 10%.
+ if math.Abs(actualRatio-wantedRatio) > 0.05 {
+ t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 5335897f5..842a16277 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
@@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -122,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -163,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false,
+ false, /* reuse */
+ 0, /* bindtoDevice */
); err != nil {
return err
}
@@ -251,7 +257,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp
return 0, 0, nil
}
-func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
return true
}
@@ -277,10 +283,17 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeTransFactory() stack.TransportProtocol {
+ return &fakeTransportProtocol{}
+}
+
func TestTransportReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -340,9 +353,12 @@ func TestTransportReceive(t *testing.T) {
}
func TestTransportControlReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -408,9 +424,12 @@ func TestTransportControlReceive(t *testing.T) {
}
func TestTransportSend(t *testing.T) {
- id, _ := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -452,7 +471,10 @@ func TestTransportSend(t *testing.T) {
}
func TestTransportOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
// Try an unsupported transport protocol.
if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -493,20 +515,23 @@ func TestTransportOptions(t *testing.T) {
}
func TestTransportForwarding(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
s.SetForwarding(true)
// TODO(b/123449044): Change this to a channel NIC.
- id1 := loopback.New()
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := loopback.New()
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatalf("CreateNIC #1 failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress #1 failed: %v", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatalf("CreateNIC #2 failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -545,7 +570,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- linkEP2.Inject(fakeNetNumber, req.ToVectorisedView())
+ ep2.Inject(fakeNetNumber, req.ToVectorisedView())
aep, _, err := ep.Accept()
if err != nil || aep == nil {
@@ -559,7 +584,7 @@ func TestTransportForwarding(t *testing.T) {
var p channel.PacketInfo
select {
- case p = <-linkEP2.C:
+ case p = <-ep2.C:
default:
t.Fatal("Response packet not forwarded")
}
@@ -571,9 +596,3 @@ func TestTransportForwarding(t *testing.T) {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
-
-func init() {
- stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
- return &fakeTransportProtocol{}
- })
-}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 043dd549b..faaa4a4e3 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -219,6 +219,15 @@ func (s *Subnet) Mask() AddressMask {
return s.mask
}
+// Broadcast returns the subnet's broadcast address.
+func (s *Subnet) Broadcast() Address {
+ addr := []byte(s.address)
+ for i := range addr {
+ addr[i] |= ^s.mask[i]
+ }
+ return Address(addr)
+}
+
// NICID is a number that uniquely identifies a NIC.
type NICID int32
@@ -252,31 +261,34 @@ type FullAddress struct {
Port uint16
}
-// Payload provides an interface around data that is being sent to an endpoint.
-// This allows the endpoint to request the amount of data it needs based on
-// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
-type Payload interface {
- // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
- Get(size int) ([]byte, *Error)
-
- // Size returns the payload size.
- Size() int
+// Payloader is an interface that provides data.
+//
+// This interface allows the endpoint to request the amount of data it needs
+// based on internal buffers without exposing them.
+type Payloader interface {
+ // FullPayload returns all available bytes.
+ FullPayload() ([]byte, *Error)
+
+ // Payload returns a slice containing at most size bytes.
+ Payload(size int) ([]byte, *Error)
}
-// SlicePayload implements Payload on top of slices for convenience.
+// SlicePayload implements Payloader for slices.
+//
+// This is typically used for tests.
type SlicePayload []byte
-// Get implements Payload.
-func (s SlicePayload) Get(size int) ([]byte, *Error) {
- if size > s.Size() {
- size = s.Size()
- }
- return s[:size], nil
+// FullPayload implements Payloader.FullPayload.
+func (s SlicePayload) FullPayload() ([]byte, *Error) {
+ return s, nil
}
-// Size implements Payload.
-func (s SlicePayload) Size() int {
- return len(s)
+// Payload implements Payloader.Payload.
+func (s SlicePayload) Payload(size int) ([]byte, *Error) {
+ if size > len(s) {
+ size = len(s)
+ }
+ return s[:size], nil
}
// A ControlMessages contains socket control messages for IP sockets.
@@ -329,7 +341,7 @@ type Endpoint interface {
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
- Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error)
+ Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
@@ -389,6 +401,10 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptInt sets a socket option, for simple cases where a value
+ // has the int type.
+ SetSockOptInt(opt SockOpt, v int) *Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
@@ -423,16 +439,33 @@ type WriteOptions struct {
// EndOfRecord has the same semantics as Linux's MSG_EOR.
EndOfRecord bool
+
+ // Atomic means that all data fetched from Payloader must be written to the
+ // endpoint. If Atomic is false, then data fetched from the Payloader may be
+ // discarded if available endpoint buffer space is unsufficient.
+ Atomic bool
}
// SockOpt represents socket options which values have the int type.
type SockOpt int
const (
- // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
- // unread bytes in the input buffer should be returned.
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption SockOpt = iota
+ // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the send buffer size option.
+ SendBufferSizeOption
+
+ // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the receive buffer size option.
+ ReceiveBufferSizeOption
+
+ // SendQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the output buffer should be returned.
+ SendQueueSizeOption
+
// TODO(b/137664753): convert all int socket options to be handled via
// GetSockOptInt.
)
@@ -441,18 +474,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
-// buffer size option.
-type SendBufferSizeOption int
-
-// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
-// receive buffer size option.
-type ReceiveBufferSizeOption int
-
-// SendQueueSizeOption is used in GetSockOpt to specify that the number of
-// unread bytes in the output buffer should be returned.
-type SendQueueSizeOption int
-
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
@@ -474,6 +495,10 @@ type ReuseAddressOption int
// to be bound to an identical socket address.
type ReusePortOption int
+// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
+// should bind only on a specific NIC.
+type BindToDeviceOption string
+
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
@@ -581,7 +606,7 @@ type Route struct {
}
// String implements the fmt.Stringer interface.
-func (r *Route) String() string {
+func (r Route) String() string {
var out strings.Builder
fmt.Fprintf(&out, "%s", r.Destination)
if len(r.Gateway) > 0 {
@@ -591,9 +616,6 @@ func (r *Route) String() string {
return out.String()
}
-// LinkEndpointID represents a data link layer endpoint.
-type LinkEndpointID uint64
-
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
@@ -720,6 +742,10 @@ type ICMPv4SentPacketStats struct {
// Dropped is the total number of ICMPv4 packets dropped due to link
// layer errors.
Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv6 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
}
// ICMPv4ReceivedPacketStats collects inbound ICMPv4-specific stats.
@@ -738,6 +764,10 @@ type ICMPv6SentPacketStats struct {
// Dropped is the total number of ICMPv6 packets dropped due to link
// layer errors.
Dropped *StatCounter
+
+ // RateLimited is the total number of ICMPv6 packets dropped due to
+ // rate limit being exceeded.
+ RateLimited *StatCounter
}
// ICMPv6ReceivedPacketStats collects inbound ICMPv6-specific stats.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 451d3880e..a3a910d41 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,7 +15,6 @@
package icmp
import (
- "encoding/binary"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -105,7 +104,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -205,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -290,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
@@ -320,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -332,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -342,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -368,14 +372,13 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident to the user-specified port. Sequence number should
- // already be set by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident)
-
hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(icmpv4, data)
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ icmpv4.SetIdent(ident)
data = data[header.ICMPv4MinimumSize:]
// Linux performs these basic checks.
@@ -394,14 +397,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // Set the ident. Sequence number is provided by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
-
- hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength()))
- icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
copy(icmpv6, data)
- data = data[header.ICMPv6EchoMinimumSize:]
+ // Set the ident. Sequence number is provided by the user.
+ icmpv6.SetIdent(ident)
+ data = data[header.ICMPv6MinimumSize:]
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return tcpip.ErrInvalidEndpointState
@@ -541,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */)
switch err {
case nil:
return true, nil
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 7fdba5d56..bfb16f7c3 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -14,16 +14,14 @@
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and
-// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
-// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
-// calling stack.New(). Then endpoints can be created by passing
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package icmp
import (
- "encoding/binary"
"fmt"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -35,15 +33,9 @@ import (
)
const (
- // ProtocolName4 is the string representation of the icmp protocol name.
- ProtocolName4 = "icmp4"
-
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
- // ProtocolName6 is the string representation of the icmp protocol name.
- ProtocolName6 = "icmp6"
-
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
@@ -92,7 +84,7 @@ func (p *protocol) MinimumPacketSize() int {
case ProtocolNumber4:
return header.ICMPv4MinimumSize
case ProtocolNumber6:
- return header.ICMPv6EchoMinimumSize
+ return header.ICMPv6MinimumSize
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
@@ -101,16 +93,18 @@ func (p *protocol) MinimumPacketSize() int {
func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
switch p.number {
case ProtocolNumber4:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil
+ hdr := header.ICMPv4(v)
+ return 0, hdr.Ident(), nil
case ProtocolNumber6:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ hdr := header.ICMPv6(v)
+ return 0, hdr.Ident(), nil
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool {
return true
}
@@ -124,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
- })
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
- stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
- })
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 13e17e2a6..a02731a5d 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
return 0, nil, tcpip.ErrInvalidEndpointState
}
- payloadBytes, err := payload.Get(payload.Size())
+ payloadBytes, err := p.FullPayload()
if err != nil {
- ep.mu.RUnlock()
return 0, nil, err
}
@@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
// destination address, route using that address.
if !ep.associated {
ip := header.IPv4(payloadBytes)
- if !ip.IsValid(payload.Size()) {
+ if !ip.IsValid(len(payloadBytes)) {
ep.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
@@ -493,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -505,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
ep.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSize
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
@@ -516,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
- ep.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index 783c21e6b..a2512d666 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -20,13 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-type factory struct{}
+// EndpointFactory implements stack.UnassociatedEndpointFactory.
+type EndpointFactory struct{}
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
-
-func init() {
- stack.RegisterUnassociatedFactory(factory{})
-}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 1ee1a53f8..a42e1f4a2 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "tcp_segment_list",
@@ -47,6 +49,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e9c5099ea..3ae4a5426 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -143,6 +143,15 @@ func decSynRcvdCount() {
synRcvdCount.Unlock()
}
+// synCookiesInUse() returns true if the synRcvdCount is greater than
+// SynRcvdCountThreshold.
+func synCookiesInUse() bool {
+ synRcvdCount.Lock()
+ v := synRcvdCount.value
+ synRcvdCount.Unlock()
+ return v >= SynRcvdCountThreshold
+}
+
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
@@ -233,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
@@ -446,6 +455,27 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
+ if !synCookiesInUse() {
+ // Send a reset as this is an ACK for which there is no
+ // half open connections and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here when a connection
+ // was opened and closed really quickly and a delayed
+ // ACK was received from the sender.
+ replyWithReset(s)
+ return
+ }
+
+ // Since SYN cookies are in use this is potentially an ACK to a
+ // SYN-ACK we sent but don't have a half open connection state
+ // as cookies are being used to protect against a potential SYN
+ // flood. In such cases validate the cookie and if valid create
+ // a fully connected endpoint and deliver to the accept queue.
+ //
+ // If not, silently drop the ACK to avoid leaking information
+ // when under a potential syn flood attack.
+ //
// Validate the cookie.
data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
if !ok || int(data) >= len(mssTable) {
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 00d2ae524..21038a65a 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -720,13 +720,18 @@ func (e *endpoint) handleClose() *tcpip.Error {
return nil
}
-// resetConnectionLocked sends a RST segment and puts the endpoint in an error
-// state with the given error code. This method must only be called from the
-// protocol goroutine.
+// resetConnectionLocked puts the endpoint in an error state with the given
+// error code and sends a RST if and only if the error is not ErrConnectionReset
+// indicating that the connection is being reset due to receiving a RST. This
+// method must only be called from the protocol goroutine.
func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
- e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ // Only send a reset if the connection is being aborted for a reason
+ // other than receiving a reset.
e.state = StateError
e.hardError = err
+ if err != tcpip.ErrConnectionReset {
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+ }
}
// completeWorkerLocked is called by the worker goroutine when it's about to
@@ -806,7 +811,7 @@ func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
if e.keepalive.unacked >= e.keepalive.count {
e.keepalive.Unlock()
- return tcpip.ErrConnectionReset
+ return tcpip.ErrTimeout
}
// RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
@@ -1068,6 +1073,10 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
e.workMu.Lock()
if err := funcs[v].f(); err != nil {
e.mu.Lock()
+ // Ensure we release all endpoint registration and route
+ // references as the connection is now in an error
+ // state.
+ e.workerCleanup = true
e.resetConnectionLocked(err)
// Lock released below.
epilogue()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ac927569a..f9d5e0085 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"fmt"
"math"
"strings"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -280,6 +282,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -564,11 +569,11 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -625,12 +630,12 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -806,7 +811,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
@@ -821,47 +826,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, err
}
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
-
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
}
- // Copy in memory without holding sndBufMu so that worker goroutine can
- // make progress independent of this operation.
- v, perr := p.Get(avail)
- if perr != nil {
+ // Fetch data.
+ v, perr := p.Payload(avail)
+ if perr != nil || len(v) == 0 {
+ if opts.Atomic { // See above.
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ }
+ // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a
- // write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- return 0, nil, err
- }
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
- // Discard any excess data copied in due to avail being reduced due to a
- // simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
}
// Add data to the send queue.
- l := len(v)
s := newSegmentFromView(&e.route, e.id, v)
- e.sndBufUsed += l
- e.sndBufInQueue += seqnum.Size(l)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
-
e.sndBufMu.Unlock()
// Release the endpoint lock to prevent deadlocks due to lock
// order inversion when acquiring workMu.
@@ -875,7 +885,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return int64(l), nil, nil
+
+ return int64(len(v)), nil, nil
}
// Peek reads data without consuming it from the endpoint.
@@ -946,62 +957,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
-
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1065,6 +1023,82 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -1176,6 +1210,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1198,18 +1244,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.sndBufMu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
- e.rcvListMu.Unlock()
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
@@ -1246,6 +1280,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1452,7 +1496,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if e.id.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1462,17 +1506,32 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+
+ // Calculate a port offset based on the destination IP/port and
+ // src IP to ensure that for a given tuple (srcIP, destIP,
+ // destPort) the offset used as a starting point is the same to
+ // ensure that we can cycle through the port space effectively.
+ h := jenkins.Sum32(e.stack.PortSeed())
+ h.Write([]byte(e.id.LocalAddress))
+ h.Write([]byte(e.id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.id.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
id := e.id
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
e.id = id
return true, nil
@@ -1490,7 +1549,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1634,7 +1693,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1715,7 +1774,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1725,16 +1784,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.id.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.id.LocalPort = 0
e.id.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index ee04dcfcc..d5d8ab96a 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -14,7 +14,7 @@
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// activated on the stack by passing tcp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -34,9 +34,6 @@ import (
)
const (
- // ProtocolName is the string representation of the tcp protocol name.
- ProtocolName = "tcp"
-
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
@@ -129,7 +126,7 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
-func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
@@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
- congestionControl: ccReno,
- availableCongestionControl: []string{ccReno, ccCubic},
- }
- })
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 1f9b1e0ef..735edfe55 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -664,7 +664,14 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se
segEnd = seg.sequenceNumber.Add(1)
// Transition to FIN-WAIT1 state since we're initiating an active close.
s.ep.mu.Lock()
- s.ep.state = StateFinWait1
+ switch s.ep.state {
+ case StateCloseWait:
+ // We've already received a FIN and are now sending our own. The
+ // sender is now awaiting a final ACK for this FIN.
+ s.ep.state = StateLastAck
+ default:
+ s.ep.state = StateFinWait1
+ }
s.ep.mu.Unlock()
} else {
// We're sending a non-FIN segment.
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 272bbcdbd..9fa97528b 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
enableCUBIC(t, c)
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index f79b8ec5f..089826a88 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
@@ -190,21 +190,122 @@ func TestTCPResetsSentIncrement(t *testing.T) {
}
}
+// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
+// a RST if an ACK is received on the listening socket for which there is no
+// active handshake in progress and we are not using SYN cookies.
+func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ c.GetPacket()
+
+ // Now resend the same ACK, this ACK should generate a RST as there
+ // should be no endpoint in SYN-RCVD state and we are not using
+ // syn-cookies yet. The reason we send the same ACK is we need a valid
+ // cookie(IRS) generated by the netstack without which the ACK will be
+ // rejected.
+ c.SendPacket(nil, ackHeaders)
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck)))
+}
+
func TestTCPResetsReceivedIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
stats := c.Stack().Stats()
want := stats.TCP.ResetsReceived.Value() + 1
- ackNum := seqnum.Value(789)
+ iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(ackNum, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: c.Port,
- SeqNum: c.IRS.Add(2),
- AckNum: ackNum.Add(2),
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
RcvWnd: rcvWnd,
Flags: header.TCPFlagRst,
})
@@ -214,18 +315,43 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
}
}
+func TestTCPResetsDoNotGenerateResets(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want)
+ }
+ c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
+}
+
func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -241,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -291,7 +417,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -339,11 +465,71 @@ func TestSimpleReceive(t *testing.T) {
)
}
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device string
+ want tcp.EndpointState
+ }{
+ {"RightDevice", "nic1", tcp.StateEstablished},
+ {"WrongDevice", "nic2", tcp.StateSynSent},
+ {"AnyDevice", "", tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ c.EP.SetSockOpt(bindToDevice)
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -431,8 +617,7 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
// Create a new connection with initial window size of 10.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -505,7 +690,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -574,7 +759,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -659,7 +844,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -678,8 +863,7 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -746,11 +930,9 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.Cleanup()
// Start off with a window size of 10, then shrink it to 5.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
- opt = 5
- if err := c.EP.SetSockOpt(opt); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
t.Fatalf("SetSockOpt failed: %v", err)
}
@@ -850,7 +1032,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -891,7 +1073,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, nil)
+ c.CreateConnected(789, 0, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -949,8 +1131,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -984,8 +1165,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1025,7 +1205,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1098,7 +1278,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1167,8 +1347,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the window size such that a window scale of 4 will be used.
const wnd = 65535 * 10
const ws = uint32(4)
- opt := tcpip.ReceiveBufferSizeOption(wnd)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1273,7 +1452,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Prevent the endpoint from processing packets.
test.stop(c.EP)
@@ -1323,7 +1502,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1371,7 +1550,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1453,7 +1632,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -1569,7 +1748,7 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
@@ -1578,7 +1757,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -1601,7 +1780,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1745,7 +1924,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 2
- if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1847,7 +2026,7 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1884,7 +2063,7 @@ func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -1909,7 +2088,7 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -1952,7 +2131,7 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2006,7 +2185,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
@@ -2077,7 +2256,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -2165,7 +2344,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
@@ -2251,7 +2430,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -2383,7 +2562,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
@@ -2433,7 +2612,7 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
@@ -2454,7 +2633,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
vv := c.BuildSegment(nil, &context.Headers{
@@ -2478,7 +2657,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
@@ -2509,7 +2688,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
data := []byte{1, 2, 3}
@@ -2555,7 +2734,7 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -2730,8 +2909,8 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2743,8 +2922,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2754,7 +2933,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
}
func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2800,7 +2982,10 @@ func TestDefaultBufferSizes(t *testing.T) {
}
func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2819,37 +3004,96 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func makeStack() (*stack.Stack, *tcpip.Error) {
- s := stack.New([]string{
- ipv4.ProtocolName,
- ipv6.ProtocolName,
- }, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ ipv6.NewProtocol(),
+ },
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
id := loopback.New()
if testing.Verbose() {
@@ -3105,7 +3349,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -3182,7 +3426,7 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3356,7 +3600,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
@@ -3459,8 +3703,8 @@ func TestKeepalive(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset)
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+ t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrTimeout)
}
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 272481aa0..ef823e4ae 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -137,7 +137,10 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Allow minimum send/receive buffer sizes to be 1 during tests.
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
@@ -150,11 +153,19 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
- id, linkEP := channel.New(1000, mtu, "")
+ ep := channel.New(1000, mtu, "")
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -180,7 +191,7 @@ func New(t *testing.T, mtu uint32) *Context {
return &Context{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
}
}
@@ -267,7 +278,7 @@ func (c *Context) GetPacketNonBlocking() []byte {
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2))
+ buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2))
if len(buf) > maxTotalSize {
buf = buf[:maxTotalSize]
}
@@ -286,9 +297,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
icmp.SetType(typ)
icmp.SetCode(code)
-
- copy(icmp[header.ICMPv4PayloadOffset:], p1)
- copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2)
+ const icmpv4VariableHeaderOffset = 4
+ copy(icmp[icmpv4VariableHeaderOffset:], p1)
+ copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
@@ -511,7 +522,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
}
// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
@@ -584,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -597,11 +604,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 4bec48c0f..43fcc27f0 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index ac2666f69..7a635ab8d 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "udp_packet_list",
@@ -50,6 +52,7 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index ac5905772..52f5af777 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "math"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -37,13 +36,17 @@ type udpPacket struct {
views [8]buffer.View `state:"nosave"`
}
-type endpointState int
+// EndpointState represents the state of a UDP endpoint.
+type EndpointState uint32
+// Endpoint states. Note that are represented in a netstack-specific manner and
+// may not be meaningful externally. Specifically, they need to be translated to
+// Linux's representation for these states if presented to userspace.
const (
- stateInitial endpointState = iota
- stateBound
- stateConnected
- stateClosed
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnected
+ StateClosed
)
// endpoint represents a UDP endpoint. This struct serves as the interface
@@ -74,7 +77,7 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
sndBufSize int
id stack.TransportEndpointID
- state endpointState
+ state EndpointState
bindNICID tcpip.NICID
regNICID tcpip.NICID
route stack.Route `state:"manual"`
@@ -85,6 +88,7 @@ type endpoint struct {
multicastNICID tcpip.NICID
multicastLoop bool
reusePort bool
+ bindToDevice tcpip.NICID
broadcast bool
// shutdownFlags represent the current shutdown state of the endpoint.
@@ -140,9 +144,9 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
- case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ case StateBound, StateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
@@ -163,7 +167,7 @@ func (e *endpoint) Close() {
e.route.Release()
// Update the state.
- e.state = stateClosed
+ e.state = StateClosed
e.mu.Unlock()
@@ -211,11 +215,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
- case stateInitial:
- case stateConnected:
+ case StateInitial:
+ case StateConnected:
return false, nil
- case stateBound:
+ case StateBound:
if to == nil {
return false, tcpip.ErrDestinationRequired
}
@@ -232,7 +236,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return true, nil
}
@@ -273,17 +277,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- if p.Size() > math.MaxUint16 {
- // Payload can't possibly fit in a packet.
- return 0, nil, tcpip.ErrMessageTooLong
- }
-
to := opts.To
e.mu.RLock()
@@ -322,7 +321,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
- if e.state != stateConnected {
+ if e.state != StateConnected {
return 0, nil, tcpip.ErrInvalidEndpointState
}
}
@@ -366,10 +365,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
+ if len(v) > header.UDPMaximumPacketSize {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
@@ -387,7 +390,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
@@ -400,7 +408,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -544,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.reusePort = v != 0
e.mu.Unlock()
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -566,7 +589,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -576,18 +612,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -638,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = tcpip.BindToDeviceOption("")
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -726,7 +760,7 @@ func (e *endpoint) Disconnect() *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return nil
}
id := stack.TransportEndpointID{}
@@ -741,12 +775,16 @@ func (e *endpoint) Disconnect() *tcpip.Error {
if err != nil {
return err
}
- e.state = stateBound
+ e.state = StateBound
} else {
- e.state = stateInitial
+ if e.id.LocalPort != 0 {
+ // Release the ephemeral port.
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
+ }
+ e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.id = id
e.route.Release()
e.route = stack.Route{}
@@ -772,8 +810,8 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
nicid := addr.NIC
var localPort uint16
switch e.state {
- case stateInitial:
- case stateBound, stateConnected:
+ case StateInitial:
+ case StateBound, StateConnected:
localPort = e.id.LocalPort
if e.bindNICID == 0 {
break
@@ -801,7 +839,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
RemoteAddress: r.RemoteAddress,
}
- if e.state == stateInitial {
+ if e.state == StateInitial {
id.LocalAddress = r.LocalAddress
}
@@ -823,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
}
e.id = id
@@ -832,7 +870,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.regNICID = nicid
e.effectiveNetProtos = netProtos
- e.state = stateConnected
+ e.state = StateConnected
e.rcvMu.Lock()
e.rcvReady = true
@@ -854,7 +892,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return tcpip.ErrNotConnected
}
@@ -886,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
}
@@ -903,7 +941,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
- if e.state != stateInitial {
+ if e.state != StateInitial {
return tcpip.ErrInvalidEndpointState
}
@@ -946,7 +984,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
- e.state = stateBound
+ e.state = StateBound
e.rcvMu.Lock()
e.rcvReady = true
@@ -989,7 +1027,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
- if e.state != stateConnected {
+ if e.state != StateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
@@ -1069,10 +1107,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}
-// State implements socket.Socket.State.
+// State implements tcpip.Endpoint.State.
func (e *endpoint) State() uint32 {
- // TODO(b/112063468): Translate internal state to values returned by Linux.
- return 0
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return uint32(e.state)
}
func isBroadcastOrMulticast(a tcpip.Address) bool {
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 5cbb56120..be46e6d4e 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -77,7 +77,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
}
- if e.state != stateBound && e.state != stateConnected {
+ if e.state != StateBound && e.state != StateConnected {
return
}
@@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) {
}
var err *tcpip.Error
- if e.state == stateConnected {
+ if e.state == StateConnected {
e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
if err != nil {
panic(err)
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a874fc9fd..2d0bc5221 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
@@ -84,7 +84,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
ep.dstPort = r.id.RemotePort
ep.regNICID = r.route.NICID()
- ep.state = stateConnected
+ ep.state = StateConnected
ep.rcvMu.Lock()
ep.rcvReady = true
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index f76e7fbe1..f5cc932dd 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -69,7 +66,106 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ // TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
+
+ // Only send ICMP error if the address is not a multicast/broadcast
+ // v4/v6 address or the source is not the unspecified address.
+ //
+ // See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
+ return true
+ }
+
+ // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination
+ // Unreachable messages with code:
+ //
+ // 2 (Protocol Unreachable), when the designated transport protocol
+ // is not supported; or
+ //
+ // 3 (Port Unreachable), when the designated transport protocol
+ // (e.g., UDP) is unable to demultiplex the datagram but has no
+ // protocol mechanism to inform the sender.
+ switch len(id.LocalAddress) {
+ case header.IPv4AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
+ return true
+ }
+ // As per RFC 1812 Section 4.3.2.3
+ //
+ // ICMP datagram SHOULD contain as much of the original
+ // datagram as possible without the length of the ICMP
+ // datagram exceeding 576 bytes
+ //
+ // NOTE: The above RFC referenced is different from the original
+ // recommendation in RFC 1122 where it mentioned that at least 8
+ // bytes of the payload must be included. Today linux and other
+ // systems implement the] RFC1812 definition and not the original
+ // RFC 1122 requirement.
+ mtu := int(r.MTU())
+ if mtu > header.IPv4MinimumProcessableDatagramSize {
+ mtu = header.IPv4MinimumProcessableDatagramSize
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ pkt.SetType(header.ICMPv4DstUnreachable)
+ pkt.SetCode(header.ICMPv4PortUnreachable)
+ pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, r.DefaultTTL())
+
+ case header.IPv6AddressSize:
+ if !r.Stack().AllowICMPMessage() {
+ r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
+ return true
+ }
+
+ // As per RFC 4443 section 2.4
+ //
+ // (c) Every ICMPv6 error message (type < 128) MUST include
+ // as much of the IPv6 offending (invoking) packet (the
+ // packet that caused the error) as possible without making
+ // the error message packet exceed the minimum IPv6 MTU
+ // [IPv6].
+ mtu := int(r.MTU())
+ if mtu > header.IPv6MinimumMTU {
+ mtu = header.IPv6MinimumMTU
+ }
+ headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize
+ available := int(mtu) - headerLen
+ payloadLen := len(netHeader) + vv.Size()
+ if payloadLen > available {
+ payloadLen = available
+ }
+ payload := buffer.NewVectorisedView(len(netHeader), []buffer.View{netHeader})
+ payload.Append(vv)
+ payload.CapLength(payloadLen)
+
+ hdr := buffer.NewPrependable(headerLen)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize))
+ pkt.SetType(header.ICMPv6DstUnreachable)
+ pkt.SetCode(header.ICMPv6PortUnreachable)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload))
+ r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, r.DefaultTTL())
+ }
return true
}
@@ -83,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 9da6edce2..5059ca22d 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -17,7 +17,6 @@ package udp_test
import (
"bytes"
"fmt"
- "math"
"math/rand"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -274,13 +274,17 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
+ ep := channel.New(256, mtu, "")
+ wep := stack.LinkEndpoint(ep)
- id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -306,7 +310,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
@@ -461,94 +465,70 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple) {
}
func newPayload() []byte {
- b := make([]byte, 30+rand.Intn(100))
+ return newMinPayload(30)
+}
+
+func newMinPayload(minSize int) []byte {
+ b := make([]byte, minSize+rand.Intn(100))
for i := range b {
b[i] = byte(rand.Intn(256))
}
return b
}
-func TestBindPortReuse(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- var eps [5]tcpip.Endpoint
- reusePortOpt := tcpip.ReusePortOption(1)
-
- pollChannel := make(chan tcpip.Endpoint)
- for i := 0; i < len(eps); i++ {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- var err *tcpip.Error
- eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(eps[i])
-
- defer eps[i].Close()
- if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
}
+ defer ep.Close()
- npackets := 100000
- nports := 10000
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.injectV6Packet(payload, &header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- })
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled
- // by the same socket.
- if ports[port] != ep {
- t.Fatalf("Port mismatch")
- }
- }
+ // Make an nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
}
- if len(stats) != len(eps) {
- t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
}
- // Check that a packet distribution is fair between sockets.
- for _, c := range stats {
- n := float64(npackets) / float64(len(eps))
- // The deviation is less than 10%.
- if math.Abs(float64(c)-n) > n/10 {
- t.Fatal(c, n)
- }
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
}
}
@@ -1238,3 +1218,153 @@ func TestMulticastInterfaceOption(t *testing.T) {
})
}
}
+
+// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV4UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true, will result in a payload large enough
+ // so that the final generated IPv4 packet is larger than
+ // header.IPv4MinimumProcessableDatagramSize.
+ largePayload bool
+ }{
+ {unicastV4, true, false},
+ {unicastV4, true, true},
+ {multicastV4, false, false},
+ {multicastV4, false, true},
+ {broadcast, false, false},
+ {broadcast, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(576)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv4(pkt)
+ checker.IPv4(t, hdr, checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+ checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+
+ icmpPkt := header.ICMPv4(hdr.Payload())
+ payloadIPHeader := header.IPv4(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+ }
+
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %d, want: %d", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
+
+// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
+// Unreachable message when a udp datagram is received on ports for which there
+// is no bound udp socket.
+func TestV6UnknownDestination(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ testCases := []struct {
+ flow testFlow
+ icmpRequired bool
+ // largePayload if true will result in a payload large enough to
+ // create an IPv6 packet > header.IPv6MinimumMTU bytes.
+ largePayload bool
+ }{
+ {unicastV6, true, false},
+ {unicastV6, true, true},
+ {multicastV6, false, false},
+ {multicastV6, false, true},
+ }
+ for _, tc := range testCases {
+ t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+ payload := newPayload()
+ if tc.largePayload {
+ payload = newMinPayload(1280)
+ }
+ c.injectPacket(tc.flow, payload)
+ if !tc.icmpRequired {
+ select {
+ case p := <-c.linkEP.C:
+ t.Fatalf("unexpected packet received: %+v", p)
+ case <-time.After(1 * time.Second):
+ return
+ }
+ }
+
+ select {
+ case p := <-c.linkEP.C:
+ var pkt []byte
+ pkt = append(pkt, p.Header...)
+ pkt = append(pkt, p.Payload...)
+ if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+ t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+ }
+
+ hdr := header.IPv6(pkt)
+ checker.IPv6(t, hdr, checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+ checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+
+ icmpPkt := header.ICMPv6(hdr.Payload())
+ payloadIPHeader := header.IPv6(icmpPkt.Payload())
+ wantLen := len(payload)
+ if tc.largePayload {
+ wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+ }
+ // In case of large payloads the IP packet may be truncated. Update
+ // the length field before retrieving the udp datagram payload.
+ payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+ origDgram := header.UDP(payloadIPHeader.Payload())
+ if got, want := len(origDgram.Payload()), wantLen; got != want {
+ t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+ }
+ if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+ t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("packet wasn't written out")
+ }
+ })
+ }
+}
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
index 98d51cc69..6afdb29b7 100644
--- a/pkg/tmutex/BUILD
+++ b/pkg/tmutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index cbd92fc05..8f6f180e5 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
index b7f505a84..b6bbb0ea2 100644
--- a/pkg/urpc/BUILD
+++ b/pkg/urpc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 9173dfd0f..8dc88becb 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "waiter_list",
diff --git a/runsc/BUILD b/runsc/BUILD
index 6b8c92706..e4e8e64a3 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -1,7 +1,7 @@
package(licenses = ["notice"]) # Apache 2.0
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_deb", "pkg_tar")
+load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar")
go_binary(
name = "runsc",
@@ -13,9 +13,10 @@ go_binary(
visibility = [
"//visibility:public",
],
- x_defs = {"main.version": "{VERSION}"},
+ x_defs = {"main.version": "{STABLE_VERSION}"},
deps = [
"//pkg/log",
+ "//pkg/refs",
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
@@ -45,9 +46,10 @@ go_binary(
visibility = [
"//visibility:public",
],
- x_defs = {"main.version": "{VERSION}"},
+ x_defs = {"main.version": "{STABLE_VERSION}"},
deps = [
"//pkg/log",
+ "//pkg/refs",
"//pkg/sentry/platform",
"//runsc/boot",
"//runsc/cmd",
@@ -65,19 +67,10 @@ pkg_tar(
)
pkg_tar(
- name = "runsc-tools",
- srcs = ["//runsc/tools/dockercfg"],
- mode = "0755",
- package_dir = "/usr/libexec/runsc",
- strip_prefix = "/runsc/tools/dockercfg/linux_amd64_stripped",
-)
-
-pkg_tar(
name = "debian-data",
extension = "tar.gz",
deps = [
":runsc-bin",
- ":runsc-tools",
],
)
@@ -98,13 +91,15 @@ pkg_deb(
maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>",
package = "runsc",
postinst = "debian/postinst.sh",
- tags = [
- # TODO(b/135475885): pkg_deb requires python2:
- # https://github.com/bazelbuild/bazel/issues/8443
- "manual",
- ],
version_file = ":version.txt",
visibility = [
"//visibility:public",
],
)
+
+sh_test(
+ name = "version_test",
+ size = "small",
+ srcs = ["version_test.sh"],
+ data = [":runsc"],
+)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 588bb8851..d90381c0f 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -80,6 +80,7 @@ go_library(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/urpc",
@@ -109,6 +110,7 @@ go_test(
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 7ae0dd05d..38278d0a2 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -19,6 +19,7 @@ import (
"strconv"
"strings"
+ "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/watchdog"
)
@@ -112,6 +113,34 @@ func MakeWatchdogAction(s string) (watchdog.Action, error) {
}
}
+// MakeRefsLeakMode converts type from string.
+func MakeRefsLeakMode(s string) (refs.LeakMode, error) {
+ switch strings.ToLower(s) {
+ case "disabled":
+ return refs.NoLeakChecking, nil
+ case "log-names":
+ return refs.LeaksLogWarning, nil
+ case "log-traces":
+ return refs.LeaksLogTraces, nil
+ default:
+ return 0, fmt.Errorf("invalid refs leakmode %q", s)
+ }
+}
+
+func refsLeakModeToString(mode refs.LeakMode) string {
+ switch mode {
+ // If not set, default it to disabled.
+ case refs.UninitializedLeakChecking, refs.NoLeakChecking:
+ return "disabled"
+ case refs.LeaksLogWarning:
+ return "log-names"
+ case refs.LeaksLogTraces:
+ return "log-traces"
+ default:
+ panic(fmt.Sprintf("Invalid leakmode: %d", mode))
+ }
+}
+
// Config holds configuration that is not part of the runtime spec.
type Config struct {
// RootDir is the runtime root directory.
@@ -138,6 +167,9 @@ type Config struct {
// Overlay is whether to wrap the root filesystem in an overlay.
Overlay bool
+ // FSGoferHostUDS enables the gofer to mount a host UDS.
+ FSGoferHostUDS bool
+
// Network indicates what type of network to use.
Network NetworkType
@@ -182,12 +214,6 @@ type Config struct {
// RestoreFile is the path to the saved container image
RestoreFile string
- // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
- // tests. It allows runsc to start the sandbox process as the current
- // user, and without chrooting the sandbox process. This can be
- // necessary in test environments that have limited capabilities.
- TestOnlyAllowRunAsCurrentUserWithoutChroot bool
-
// NumNetworkChannels controls the number of AF_PACKET sockets that map
// to the same underlying network device. This allows netstack to better
// scale for high throughput use cases.
@@ -201,6 +227,22 @@ type Config struct {
// AlsoLogToStderr allows to send log messages to stderr.
AlsoLogToStderr bool
+
+ // ReferenceLeakMode sets reference leak check mode
+ ReferenceLeakMode refs.LeakMode
+
+ // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
+ // tests. It allows runsc to start the sandbox process as the current
+ // user, and without chrooting the sandbox process. This can be
+ // necessary in test environments that have limited capabilities.
+ TestOnlyAllowRunAsCurrentUserWithoutChroot bool
+
+ // TestOnlyTestNameEnv should only be used in tests. It looks up for the
+ // test name in the container environment variables and adds it to the debug
+ // log file name. This is done to help identify the log with the test when
+ // multiple tests are run in parallel, since there is no way to pass
+ // parameters to the runtime from docker.
+ TestOnlyTestNameEnv string
}
// ToFlags returns a slice of flags that correspond to the given Config.
@@ -214,6 +256,7 @@ func (c *Config) ToFlags() []string {
"--debug-log-format=" + c.DebugLogFormat,
"--file-access=" + c.FileAccess.String(),
"--overlay=" + strconv.FormatBool(c.Overlay),
+ "--fsgofer-host-uds=" + strconv.FormatBool(c.FSGoferHostUDS),
"--network=" + c.Network.String(),
"--log-packets=" + strconv.FormatBool(c.LogPackets),
"--platform=" + c.Platform,
@@ -227,10 +270,14 @@ func (c *Config) ToFlags() []string {
"--num-network-channels=" + strconv.Itoa(c.NumNetworkChannels),
"--rootless=" + strconv.FormatBool(c.Rootless),
"--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr),
+ "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode),
}
+ // Only include these if set since it is never to be used by users.
if c.TestOnlyAllowRunAsCurrentUserWithoutChroot {
- // Only include if set since it is never to be used by users.
- f = append(f, "-TESTONLY-unsafe-nonroot=true")
+ f = append(f, "--TESTONLY-unsafe-nonroot=true")
+ }
+ if len(c.TestOnlyTestNameEnv) != 0 {
+ f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv)
}
return f
}
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 7ca776b3a..a2ecc6bcb 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -88,14 +88,24 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
seccomp.AllowAny{},
seccomp.AllowAny{},
- seccomp.AllowValue(0),
},
{
seccomp.AllowAny{},
seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
seccomp.AllowAny{},
+ },
+ // Non-private variants are included for flipcall support. They are otherwise
+ // unncessary, as the sentry will use only private futexes internally.
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
seccomp.AllowAny{},
- seccomp.AllowValue(0),
},
},
syscall.SYS_GETPID: {},
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index b6eeacf98..34c674840 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -25,19 +25,21 @@ import (
// Include filesystem types that OCI spec might mount.
_ "gvisor.dev/gvisor/pkg/sentry/fs/dev"
- "gvisor.dev/gvisor/pkg/sentry/fs/gofer"
_ "gvisor.dev/gvisor/pkg/sentry/fs/host"
_ "gvisor.dev/gvisor/pkg/sentry/fs/proc"
- "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/sys"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tmpfs"
_ "gvisor.dev/gvisor/pkg/sentry/fs/tty"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/gofer"
+ "gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -261,6 +263,18 @@ func subtargets(root string, mnts []specs.Mount) []string {
return targets
}
+func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error {
+ mns, err := mntr.setupFS(conf, procArgs)
+ if err != nil {
+ return err
+ }
+
+ // Set namespace here so that it can be found in ctx.
+ procArgs.MountNamespace = mns
+
+ return setExecutablePath(ctx, procArgs)
+}
+
// setExecutablePath sets the procArgs.Filename by searching the PATH for an
// executable matching the procArgs.Argv[0].
func setExecutablePath(ctx context.Context, procArgs *kernel.CreateProcessArgs) error {
@@ -500,73 +514,95 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin
}
}
-// setupChildContainer is used to set up the file system for non-root containers
-// and amend the procArgs accordingly. This is the main entry point for this
-// rest of functions in this file. procArgs are passed by reference and the
-// FDMap field is modified. It dups stdioFDs.
-func (c *containerMounter) setupChildContainer(conf *Config, procArgs *kernel.CreateProcessArgs) error {
- // Setup a child container.
- log.Infof("Creating new process in child container.")
-
- // Create a new root inode and mount namespace for the container.
- rootCtx := c.k.SupervisorContext()
- rootInode, err := c.createRootMount(rootCtx, conf)
- if err != nil {
- return fmt.Errorf("creating filesystem for container: %v", err)
+// processHints processes annotations that container hints about how volumes
+// should be mounted (e.g. a volume shared between containers). It must be
+// called for the root container only.
+func (c *containerMounter) processHints(conf *Config) error {
+ ctx := c.k.SupervisorContext()
+ for _, hint := range c.hints.mounts {
+ log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
+ inode, err := c.mountSharedMaster(ctx, conf, hint)
+ if err != nil {
+ return fmt.Errorf("mounting shared master %q: %v", hint.name, err)
+ }
+ hint.root = inode
}
- mns, err := fs.NewMountNamespace(rootCtx, rootInode)
+ return nil
+}
+
+// setupFS is used to set up the file system for all containers. This is the
+// main entry point method, with most of the other being internal only. It
+// returns the mount namespace that is created for the container.
+func (c *containerMounter) setupFS(conf *Config, procArgs *kernel.CreateProcessArgs) (*fs.MountNamespace, error) {
+ log.Infof("Configuring container's file system")
+
+ // Create context with root credentials to mount the filesystem (the current
+ // user may not be privileged enough).
+ rootProcArgs := *procArgs
+ rootProcArgs.WorkingDirectory = "/"
+ rootProcArgs.Credentials = auth.NewRootCredentials(procArgs.Credentials.UserNamespace)
+ rootProcArgs.Umask = 0022
+ rootProcArgs.MaxSymlinkTraversals = linux.MaxSymlinkTraversals
+ rootCtx := rootProcArgs.NewContext(c.k)
+
+ mns, err := c.createMountNamespace(rootCtx, conf)
if err != nil {
- return fmt.Errorf("creating new mount namespace for container: %v", err)
+ return nil, err
}
- procArgs.MountNamespace = mns
- root := mns.Root()
- defer root.DecRef()
- // Mount all submounts.
- if err := c.mountSubmounts(rootCtx, conf, mns, root); err != nil {
- return err
+ // Set namespace here so that it can be found in rootCtx.
+ rootProcArgs.MountNamespace = mns
+
+ if err := c.mountSubmounts(rootCtx, conf, mns); err != nil {
+ return nil, err
}
- return c.checkDispenser()
+ return mns, nil
}
-func (c *containerMounter) checkDispenser() error {
- if !c.fds.empty() {
- return fmt.Errorf("not all gofer FDs were consumed, remaining: %v", c.fds)
+func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Config) (*fs.MountNamespace, error) {
+ rootInode, err := c.createRootMount(ctx, conf)
+ if err != nil {
+ return nil, fmt.Errorf("creating filesystem for container: %v", err)
}
- return nil
+ mns, err := fs.NewMountNamespace(ctx, rootInode)
+ if err != nil {
+ return nil, fmt.Errorf("creating new mount namespace for container: %v", err)
+ }
+ return mns, nil
}
-// setupRootContainer creates a mount namespace containing the root filesystem
-// and all mounts. 'rootCtx' is used to walk directories to find mount points.
-// The 'setMountNS' callback is called after the mount namespace is created and
-// will get a reference on that namespace. The callback must ensure that the
-// rootCtx has the provided mount namespace.
-func (c *containerMounter) setupRootContainer(userCtx context.Context, rootCtx context.Context, conf *Config, setMountNS func(*fs.MountNamespace)) error {
- for _, hint := range c.hints.mounts {
- log.Infof("Mounting master of shared mount %q from %q type %q", hint.name, hint.mount.Source, hint.mount.Type)
- inode, err := c.mountSharedMaster(rootCtx, conf, hint)
- if err != nil {
- return fmt.Errorf("mounting shared master %q: %v", hint.name, err)
+func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error {
+ root := mns.Root()
+ defer root.DecRef()
+
+ for _, m := range c.mounts {
+ log.Debugf("Mounting %q to %q, type: %s, options: %s", m.Source, m.Destination, m.Type, m.Options)
+ if hint := c.hints.findMount(m); hint != nil && hint.isSupported() {
+ if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil {
+ return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, m.Destination, err)
+ }
+ } else {
+ if err := c.mountSubmount(ctx, conf, mns, root, m); err != nil {
+ return fmt.Errorf("mount submount %q: %v", m.Destination, err)
+ }
}
- hint.root = inode
}
- rootInode, err := c.createRootMount(rootCtx, conf)
- if err != nil {
- return fmt.Errorf("creating root mount: %v", err)
+ if err := c.mountTmp(ctx, conf, mns, root); err != nil {
+ return fmt.Errorf("mount submount %q: %v", "tmp", err)
}
- mns, err := fs.NewMountNamespace(userCtx, rootInode)
- if err != nil {
- return fmt.Errorf("creating root mount namespace: %v", err)
+
+ if err := c.checkDispenser(); err != nil {
+ return err
}
- setMountNS(mns)
+ return nil
+}
- root := mns.Root()
- defer root.DecRef()
- if err := c.mountSubmounts(rootCtx, conf, mns, root); err != nil {
- return fmt.Errorf("mounting submounts: %v", err)
+func (c *containerMounter) checkDispenser() error {
+ if !c.fds.empty() {
+ return fmt.Errorf("not all gofer FDs were consumed, remaining: %v", c.fds)
}
- return c.checkDispenser()
+ return nil
}
// mountSharedMaster mounts the master of a volume that is shared among
@@ -684,25 +720,6 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (
return fsName, opts, useOverlay, err
}
-func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent) error {
- for _, m := range c.mounts {
- if hint := c.hints.findMount(m); hint != nil && hint.isSupported() {
- if err := c.mountSharedSubmount(ctx, mns, root, m, hint); err != nil {
- return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, m.Destination, err)
- }
- } else {
- if err := c.mountSubmount(ctx, conf, mns, root, m); err != nil {
- return fmt.Errorf("mount submount %q: %v", m.Destination, err)
- }
- }
- }
-
- if err := c.mountTmp(ctx, conf, mns, root); err != nil {
- return fmt.Errorf("mount submount %q: %v", "tmp", err)
- }
- return nil
-}
-
// mountSubmount mounts volumes inside the container's root. Because mounts may
// be readonly, a lower ramfs overlay is added to create the mount point dir.
// Another overlay is added with tmpfs on top if Config.Overlay is true.
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index f91158027..adf345490 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -20,7 +20,6 @@ import (
mrand "math/rand"
"os"
"runtime"
- "strings"
"sync"
"sync/atomic"
"syscall"
@@ -33,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/memutil"
"gvisor.dev/gvisor/pkg/rand"
- "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -56,6 +54,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/runsc/boot/filter"
@@ -527,34 +526,21 @@ func (l *Loader) run() error {
// Setup the root container file system.
l.startGoferMonitor(l.sandboxID, l.goferFDs)
+
mntr := newContainerMounter(l.spec, l.goferFDs, l.k, l.mountHints)
- if err := mntr.setupRootContainer(ctx, ctx, l.conf, func(mns *fs.MountNamespace) {
- l.rootProcArgs.MountNamespace = mns
- }); err != nil {
+ if err := mntr.processHints(l.conf); err != nil {
return err
}
-
- if err := setExecutablePath(ctx, &l.rootProcArgs); err != nil {
+ if err := setupContainerFS(ctx, l.conf, mntr, &l.rootProcArgs); err != nil {
return err
}
- // Read /etc/passwd for the user's HOME directory and set the HOME
- // environment variable as required by POSIX if it is not overridden by
- // the user.
- hasHomeEnvv := false
- for _, envv := range l.rootProcArgs.Envv {
- if strings.HasPrefix(envv, "HOME=") {
- hasHomeEnvv = true
- }
- }
- if !hasHomeEnvv {
- homeDir, err := getExecUserHome(ctx, l.rootProcArgs.MountNamespace, uint32(l.rootProcArgs.Credentials.RealKUID))
- if err != nil {
- return fmt.Errorf("error reading exec user: %v", err)
- }
-
- l.rootProcArgs.Envv = append(l.rootProcArgs.Envv, "HOME="+homeDir)
+ // Add the HOME enviroment variable if it is not already set.
+ envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ if err != nil {
+ return err
}
+ l.rootProcArgs.Envv = envv
// Create the root container init task. It will begin running
// when the kernel is started.
@@ -687,13 +673,10 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file
// Setup the child container file system.
l.startGoferMonitor(cid, goferFDs)
- mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints)
- if err := mntr.setupChildContainer(conf, &procArgs); err != nil {
- return fmt.Errorf("configuring container FS: %v", err)
- }
- if err := setExecutablePath(ctx, &procArgs); err != nil {
- return fmt.Errorf("setting executable path for %+v: %v", procArgs, err)
+ mntr := newContainerMounter(spec, goferFDs, l.k, l.mountHints)
+ if err := setupContainerFS(ctx, conf, mntr, &procArgs); err != nil {
+ return err
}
// Create and start the new process.
@@ -766,26 +749,34 @@ func (l *Loader) destroyContainer(cid string) error {
if err := l.signalAllProcesses(cid, int32(linux.SIGKILL)); err != nil {
return fmt.Errorf("sending SIGKILL to all container processes: %v", err)
}
+ // Wait for all processes that belong to the container to exit (including
+ // exec'd processes).
+ for _, t := range l.k.TaskSet().Root.Tasks() {
+ if t.ContainerID() == cid {
+ t.ThreadGroup().WaitExited()
+ }
+ }
+
+ // At this point, all processes inside of the container have exited,
+ // releasing all references to the container's MountNamespace and
+ // causing all submounts and overlays to be unmounted.
+ //
+ // Since the container's MountNamespace has been released,
+ // MountNamespace.destroy() will have executed, but that function may
+ // trigger async close operations. We must wait for those to complete
+ // before returning, otherwise the caller may kill the gofer before
+ // they complete, causing a cascade of failing RPCs.
+ fs.AsyncBarrier()
}
- // Remove all container thread groups from the map.
+ // No more failure from this point on. Remove all container thread groups
+ // from the map.
for key := range l.processes {
if key.cid == cid {
delete(l.processes, key)
}
}
- // At this point, all processes inside of the container have exited,
- // releasing all references to the container's MountNamespace and
- // causing all submounts and overlays to be unmounted.
- //
- // Since the container's MountNamespace has been released,
- // MountNamespace.destroy() will have executed, but that function may
- // trigger async close operations. We must wait for those to complete
- // before returning, otherwise the caller may kill the gofer before
- // they complete, causing a cascade of failing RPCs.
- fs.AsyncBarrier()
-
log.Debugf("Container destroyed %q", cid)
return nil
}
@@ -813,6 +804,16 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
})
defer args.MountNamespace.DecRef()
+ // Add the HOME enviroment varible if it is not already set.
+ root := args.MountNamespace.Root()
+ defer root.DecRef()
+ ctx := fs.WithRoot(l.k.SupervisorContext(), root)
+ envv, err := maybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
+
// Start the process.
proc := control.Proc{Kernel: l.k}
args.PIDNamespace = tg.PIDNamespace()
@@ -911,15 +912,17 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
case NetworkNone, NetworkSandbox:
// NetworkNone sets up loopback using netstack.
- netProtos := []string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}
- protoNames := []string{tcp.ProtocolName, udp.ProtocolName, icmp.ProtocolName4}
- s := epsocket.Stack{stack.New(netProtos, protoNames, stack.Options{
- Clock: clock,
- Stats: epsocket.Metrics,
- HandleLocal: true,
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := epsocket.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: epsocket.Metrics,
+ HandleLocal: true,
// Enable raw sockets for users with sufficient
// privileges.
- Raw: true,
+ UnassociatedFactory: raw.EndpointFactory{},
})}
// Enable SACK Recovery.
@@ -1043,21 +1046,8 @@ func (l *Loader) signalAllProcesses(cid string, signo int32) error {
// the signal is delivered. This prevents process leaks when SIGKILL is
// sent to the entire container.
l.k.Pause()
- if err := l.k.SendContainerSignal(cid, &arch.SignalInfo{Signo: signo}); err != nil {
- l.k.Unpause()
- return err
- }
- l.k.Unpause()
-
- // If SIGKILLing all processes, wait for them to exit.
- if linux.Signal(signo) == linux.SIGKILL {
- for _, t := range l.k.TaskSet().Root.Tasks() {
- if t.ContainerID() == cid {
- t.ThreadGroup().WaitExited()
- }
- }
- }
- return nil
+ defer l.k.Unpause()
+ return l.k.SendContainerSignal(cid, &arch.SignalInfo{Signo: signo})
}
// threadGroupFromID same as threadGroupFromIDLocked except that it acquires
@@ -1090,8 +1080,3 @@ func (l *Loader) threadGroupFromIDLocked(key execID) (*kernel.ThreadGroup, *host
}
return ep.tg, ep.tty, true, nil
}
-
-func init() {
- // TODO(gvisor.dev/issue/365): Make this configurable.
- refs.SetLeakMode(refs.NoLeakChecking)
-}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index e0e32b9d5..147ff7703 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -401,17 +401,16 @@ func TestCreateMountNamespace(t *testing.T) {
}
defer cleanup()
- // setupRootContainer needs to find root from the context after the
- // namespace is created.
- var mns *fs.MountNamespace
- setMountNS := func(m *fs.MountNamespace) {
- mns = m
- ctx.(*contexttest.TestContext).RegisterValue(fs.CtxRoot, mns.Root())
- }
mntr := newContainerMounter(&tc.spec, []int{sandEnd}, nil, &podMountHints{})
- if err := mntr.setupRootContainer(ctx, ctx, conf, setMountNS); err != nil {
- t.Fatalf("createMountNamespace test case %q failed: %v", tc.name, err)
+ mns, err := mntr.createMountNamespace(ctx, conf)
+ if err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
}
+ ctx = fs.WithRoot(ctx, mns.Root())
+ if err := mntr.mountSubmounts(ctx, conf, mns); err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
+ }
+
root := mns.Root()
defer root.DecRef()
for _, p := range tc.expectedPaths {
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index ea0d9f790..32cba5ac1 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -121,10 +121,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
nicID++
nicids[link.Name] = nicID
- linkEP := loopback.New()
+ ep := loopback.New()
log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil {
return err
}
@@ -156,7 +156,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
mac := tcpip.LinkAddress(link.LinkAddress)
- linkEP, err := fdbased.New(&fdbased.Options{
+ ep, err := fdbased.New(&fdbased.Options{
FDs: FDs,
MTU: uint32(link.MTU),
EthernetHeader: true,
@@ -170,7 +170,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil {
return err
}
@@ -203,14 +203,14 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
// createNICWithAddrs creates a NIC in the network stack and adds the given
// addresses.
-func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, addrs []net.IP, loopback bool) error {
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error {
if loopback {
- if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(linkEP)); err != nil {
- return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, linkEP, err)
+ if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil {
+ return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err)
}
} else {
- if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(linkEP)); err != nil {
- return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, linkEP, err)
+ if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil {
+ return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err)
}
}
diff --git a/runsc/boot/user.go b/runsc/boot/user.go
index d1d423a5c..56cc12ee0 100644
--- a/runsc/boot/user.go
+++ b/runsc/boot/user.go
@@ -16,6 +16,7 @@ package boot
import (
"bufio"
+ "fmt"
"io"
"strconv"
"strings"
@@ -23,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
@@ -42,7 +44,7 @@ func (r *fileReader) Read(buf []byte) (int, error) {
// getExecUserHome returns the home directory of the executing user read from
// /etc/passwd as read from the container filesystem.
-func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32) (string, error) {
+func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.KUID) (string, error) {
// The default user home directory to return if no user matching the user
// if found in the /etc/passwd found in the image.
const defaultHome = "/"
@@ -82,7 +84,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
File: f,
}
- homeDir, err := findHomeInPasswd(uid, r, defaultHome)
+ homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
if err != nil {
return "", err
}
@@ -90,6 +92,28 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
return homeDir, nil
}
+// maybeAddExecUserHome returns a new slice with the HOME enviroment variable
+// set if the slice does not already contain it, otherwise it returns the
+// original slice unmodified.
+func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHome(ctx, mns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+ return append(envv, "HOME="+homeDir), nil
+}
+
// findHomeInPasswd parses a passwd file and returns the given user's home
// directory. This function does it's best to replicate the runc's behavior.
func findHomeInPasswd(uid uint32, passwd io.Reader, defaultHome string) (string, error) {
diff --git a/runsc/boot/user_test.go b/runsc/boot/user_test.go
index 01f666507..9aee2ad07 100644
--- a/runsc/boot/user_test.go
+++ b/runsc/boot/user_test.go
@@ -25,6 +25,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
func setupTempDir() (string, error) {
@@ -68,7 +69,7 @@ func setupPasswd(contents string, perms os.FileMode) func() (string, error) {
// TestGetExecUserHome tests the getExecUserHome function.
func TestGetExecUserHome(t *testing.T) {
tests := map[string]struct {
- uid uint32
+ uid auth.KUID
createRoot func() (string, error)
expected string
}{
@@ -164,13 +165,13 @@ func TestGetExecUserHome(t *testing.T) {
},
}
- var mns *fs.MountNamespace
- setMountNS := func(m *fs.MountNamespace) {
- mns = m
- ctx.(*contexttest.TestContext).RegisterValue(fs.CtxRoot, mns.Root())
- }
mntr := newContainerMounter(spec, []int{sandEnd}, nil, &podMountHints{})
- if err := mntr.setupRootContainer(ctx, ctx, conf, setMountNS); err != nil {
+ mns, err := mntr.createMountNamespace(ctx, conf)
+ if err != nil {
+ t.Fatalf("failed to create mount namespace: %v", err)
+ }
+ ctx = fs.WithRoot(ctx, mns.Root())
+ if err := mntr.mountSubmounts(ctx, conf, mns); err != nil {
t.Fatalf("failed to create mount namespace: %v", err)
}
diff --git a/runsc/cgroup/BUILD b/runsc/cgroup/BUILD
index ab2387614..d6165f9e5 100644
--- a/runsc/cgroup/BUILD
+++ b/runsc/cgroup/BUILD
@@ -6,9 +6,7 @@ go_library(
name = "cgroup",
srcs = ["cgroup.go"],
importpath = "gvisor.dev/gvisor/runsc/cgroup",
- visibility = [
- "//runsc:__subpackages__",
- ],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/log",
"//runsc/specutils",
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 5223b9972..250845ad7 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -19,6 +19,7 @@ go_library(
"exec.go",
"gofer.go",
"help.go",
+ "install.go",
"kill.go",
"list.go",
"path.go",
@@ -81,7 +82,7 @@ go_test(
"//runsc/boot",
"//runsc/container",
"//runsc/specutils",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_google_go-cmp//cmp:go_default_library",
"@com_github_google_go-cmp//cmp/cmpopts:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go
index 3ae25a257..0c27f7313 100644
--- a/runsc/cmd/capability_test.go
+++ b/runsc/cmd/capability_test.go
@@ -15,6 +15,7 @@
package cmd
import (
+ "flag"
"fmt"
"os"
"testing"
@@ -25,7 +26,7 @@ import (
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func init() {
@@ -121,6 +122,7 @@ func TestCapabilities(t *testing.T) {
}
func TestMain(m *testing.M) {
+ flag.Parse()
specutils.MaybeRunAsRoot()
os.Exit(m.Run())
}
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index e817eff77..d1e99243b 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -105,11 +105,11 @@ func (ex *Exec) SetFlags(f *flag.FlagSet) {
// Execute implements subcommands.Command.Execute. It starts a process in an
// already created container.
func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- e, id, err := ex.parseArgs(f)
+ conf := args[0].(*boot.Config)
+ e, id, err := ex.parseArgs(f, conf.EnableRaw)
if err != nil {
Fatalf("parsing process spec: %v", err)
}
- conf := args[0].(*boot.Config)
waitStatus := args[1].(*syscall.WaitStatus)
c, err := container.Load(conf.RootDir, id)
@@ -117,6 +117,9 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("loading sandbox: %v", err)
}
+ log.Debugf("Exec arguments: %+v", e)
+ log.Debugf("Exec capablities: %+v", e.Capabilities)
+
// Replace empty settings with defaults from container.
if e.WorkingDirectory == "" {
e.WorkingDirectory = c.Spec.Process.Cwd
@@ -127,15 +130,13 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("getting environment variables: %v", err)
}
}
+
if e.Capabilities == nil {
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- e.Capabilities, err = specutils.Capabilities(true /* enableRaw */, c.Spec.Process.Capabilities)
+ e.Capabilities, err = specutils.Capabilities(conf.EnableRaw, c.Spec.Process.Capabilities)
if err != nil {
Fatalf("creating capabilities: %v", err)
}
+ log.Infof("Using exec capabilities from container: %+v", e.Capabilities)
}
// containerd expects an actual process to represent the container being
@@ -282,14 +283,14 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
// parseArgs parses exec information from the command line or a JSON file
// depending on whether the --process flag was used. Returns an ExecArgs and
// the ID of the container to be used.
-func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
+func (ex *Exec) parseArgs(f *flag.FlagSet, enableRaw bool) (*control.ExecArgs, string, error) {
if ex.processPath == "" {
// Requires at least a container ID and command.
if f.NArg() < 2 {
f.Usage()
return nil, "", fmt.Errorf("both a container-id and command are required")
}
- e, err := ex.argsFromCLI(f.Args()[1:])
+ e, err := ex.argsFromCLI(f.Args()[1:], enableRaw)
return e, f.Arg(0), err
}
// Requires only the container ID.
@@ -297,11 +298,11 @@ func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
f.Usage()
return nil, "", fmt.Errorf("a container-id is required")
}
- e, err := ex.argsFromProcessFile()
+ e, err := ex.argsFromProcessFile(enableRaw)
return e, f.Arg(0), err
}
-func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
+func (ex *Exec) argsFromCLI(argv []string, enableRaw bool) (*control.ExecArgs, error) {
extraKGIDs := make([]auth.KGID, 0, len(ex.extraKGIDs))
for _, s := range ex.extraKGIDs {
kgid, err := strconv.Atoi(s)
@@ -314,7 +315,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
var caps *auth.TaskCapabilities
if len(ex.caps) > 0 {
var err error
- caps, err = capabilities(ex.caps)
+ caps, err = capabilities(ex.caps, enableRaw)
if err != nil {
return nil, fmt.Errorf("capabilities error: %v", err)
}
@@ -332,7 +333,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
}, nil
}
-func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
+func (ex *Exec) argsFromProcessFile(enableRaw bool) (*control.ExecArgs, error) {
f, err := os.Open(ex.processPath)
if err != nil {
return nil, fmt.Errorf("error opening process file: %s, %v", ex.processPath, err)
@@ -342,21 +343,21 @@ func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, fmt.Errorf("error parsing process file: %s, %v", ex.processPath, err)
}
- return argsFromProcess(&p)
+ return argsFromProcess(&p, enableRaw)
}
// argsFromProcess performs all the non-IO conversion from the Process struct
// to ExecArgs.
-func argsFromProcess(p *specs.Process) (*control.ExecArgs, error) {
+func argsFromProcess(p *specs.Process, enableRaw bool) (*control.ExecArgs, error) {
// Create capabilities.
var caps *auth.TaskCapabilities
if p.Capabilities != nil {
var err error
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- caps, err = specutils.Capabilities(true /* enableRaw */, p.Capabilities)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ caps, err = specutils.Capabilities(enableRaw, p.Capabilities)
if err != nil {
return nil, fmt.Errorf("error creating capabilities: %v", err)
}
@@ -409,7 +410,7 @@ func resolveEnvs(envs ...[]string) ([]string, error) {
// capabilities takes a list of capabilities as strings and returns an
// auth.TaskCapabilities struct with those capabilities in every capability set.
// This mimics runc's behavior.
-func capabilities(cs []string) (*auth.TaskCapabilities, error) {
+func capabilities(cs []string, enableRaw bool) (*auth.TaskCapabilities, error) {
var specCaps specs.LinuxCapabilities
for _, cap := range cs {
specCaps.Ambient = append(specCaps.Ambient, cap)
@@ -418,11 +419,11 @@ func capabilities(cs []string) (*auth.TaskCapabilities, error) {
specCaps.Inheritable = append(specCaps.Inheritable, cap)
specCaps.Permitted = append(specCaps.Permitted, cap)
}
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec requires
- // the capability to be set explicitly, while 'docker run' sets it by
- // default.
- return specutils.Capabilities(true /* enableRaw */, &specCaps)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ return specutils.Capabilities(enableRaw, &specCaps)
}
// stringSlice allows a flag to be used multiple times, where each occurrence
diff --git a/runsc/cmd/exec_test.go b/runsc/cmd/exec_test.go
index eb38a431f..a1e980d08 100644
--- a/runsc/cmd/exec_test.go
+++ b/runsc/cmd/exec_test.go
@@ -91,7 +91,7 @@ func TestCLIArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := tc.ex.argsFromCLI(tc.argv)
+ e, err := tc.ex.argsFromCLI(tc.argv, true)
if err != nil {
t.Errorf("argsFromCLI(%+v): got error: %+v", tc.ex, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
@@ -144,7 +144,7 @@ func TestJSONArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := argsFromProcess(&tc.p)
+ e, err := argsFromProcess(&tc.p, true)
if err != nil {
t.Errorf("argsFromProcess(%+v): got error: %+v", tc.p, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 9faabf494..fbd579fb8 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -182,6 +182,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
cfg := fsgofer.Config{
ROMount: isReadonlyMount(m.Options),
PanicOnWrite: g.panicOnWrite,
+ HostUDS: conf.FSGoferHostUDS,
}
ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)
if err != nil {
@@ -200,6 +201,10 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("too many FDs passed for mounts. mounts: %d, FDs: %d", mountIdx, len(g.ioFDs))
}
+ if conf.FSGoferHostUDS {
+ filter.InstallUDSFilters()
+ }
+
if err := filter.Install(); err != nil {
Fatalf("installing seccomp filters: %v", err)
}
diff --git a/runsc/cmd/install.go b/runsc/cmd/install.go
new file mode 100644
index 000000000..441c1db0d
--- /dev/null
+++ b/runsc/cmd/install.go
@@ -0,0 +1,210 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "path"
+
+ "flag"
+ "github.com/google/subcommands"
+)
+
+// Install implements subcommands.Command.
+type Install struct {
+ ConfigFile string
+ Runtime string
+ Experimental bool
+}
+
+// Name implements subcommands.Command.Name.
+func (*Install) Name() string {
+ return "install"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Install) Synopsis() string {
+ return "adds a runtime to docker daemon configuration"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Install) Usage() string {
+ return `install [flags] <name> [-- [args...]] -- if provided, args are passed to the runtime
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (i *Install) SetFlags(fs *flag.FlagSet) {
+ fs.StringVar(&i.ConfigFile, "config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
+ fs.StringVar(&i.Runtime, "runtime", "runsc", "runtime name")
+ fs.BoolVar(&i.Experimental, "experimental", false, "enable experimental features")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (i *Install) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Grab the name and arguments.
+ runtimeArgs := f.Args()
+
+ // Extract the executable.
+ path, err := os.Executable()
+ if err != nil {
+ log.Fatalf("Error reading current exectuable: %v", err)
+ }
+
+ // Load the configuration file.
+ c, err := readConfig(i.ConfigFile)
+ if err != nil {
+ log.Fatalf("Error reading config file %q: %v", i.ConfigFile, err)
+ }
+
+ // Add the given runtime.
+ var rts map[string]interface{}
+ if i, ok := c["runtimes"]; ok {
+ rts = i.(map[string]interface{})
+ } else {
+ rts = make(map[string]interface{})
+ c["runtimes"] = rts
+ }
+ rts[i.Runtime] = struct {
+ Path string `json:"path,omitempty"`
+ RuntimeArgs []string `json:"runtimeArgs,omitempty"`
+ }{
+ Path: path,
+ RuntimeArgs: runtimeArgs,
+ }
+
+ // Set experimental if required.
+ if i.Experimental {
+ c["experimental"] = true
+ }
+
+ // Write out the runtime.
+ if err := writeConfig(c, i.ConfigFile); err != nil {
+ log.Fatalf("Error writing config file %q: %v", i.ConfigFile, err)
+ }
+
+ // Success.
+ log.Printf("Added runtime %q with arguments %v to %q.", i.Runtime, runtimeArgs, i.ConfigFile)
+ return subcommands.ExitSuccess
+}
+
+// Uninstall implements subcommands.Command.
+type Uninstall struct {
+ ConfigFile string
+ Runtime string
+}
+
+// Name implements subcommands.Command.Name.
+func (*Uninstall) Name() string {
+ return "uninstall"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*Uninstall) Synopsis() string {
+ return "removes a runtime from docker daemon configuration"
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*Uninstall) Usage() string {
+ return `uninstall [flags] <name>
+`
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (u *Uninstall) SetFlags(fs *flag.FlagSet) {
+ fs.StringVar(&u.ConfigFile, "config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
+ fs.StringVar(&u.Runtime, "runtime", "runsc", "runtime name")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (u *Uninstall) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ log.Printf("Removing runtime %q from %q.", u.Runtime, u.ConfigFile)
+
+ c, err := readConfig(u.ConfigFile)
+ if err != nil {
+ log.Fatalf("Error reading config file %q: %v", u.ConfigFile, err)
+ }
+
+ var rts map[string]interface{}
+ if i, ok := c["runtimes"]; ok {
+ rts = i.(map[string]interface{})
+ } else {
+ log.Fatalf("runtime %q not found", u.Runtime)
+ }
+ if _, ok := rts[u.Runtime]; !ok {
+ log.Fatalf("runtime %q not found", u.Runtime)
+ }
+ delete(rts, u.Runtime)
+
+ if err := writeConfig(c, u.ConfigFile); err != nil {
+ log.Fatalf("Error writing config file %q: %v", u.ConfigFile, err)
+ }
+ return subcommands.ExitSuccess
+}
+
+func readConfig(path string) (map[string]interface{}, error) {
+ // Read the configuration data.
+ configBytes, err := ioutil.ReadFile(path)
+ if err != nil && !os.IsNotExist(err) {
+ return nil, err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if len(configBytes) > 0 {
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return nil, err
+ }
+ }
+
+ return c, nil
+}
+
+func writeConfig(c map[string]interface{}, filename string) error {
+ // Marshal the configuration.
+ b, err := json.MarshalIndent(c, "", " ")
+ if err != nil {
+ return err
+ }
+
+ // Copy the old configuration.
+ old, err := ioutil.ReadFile(filename)
+ if err != nil {
+ if !os.IsNotExist(err) {
+ return fmt.Errorf("error reading config file %q: %v", filename, err)
+ }
+ } else {
+ if err := ioutil.WriteFile(filename+"~", old, 0644); err != nil {
+ return fmt.Errorf("error backing up config file %q: %v", filename, err)
+ }
+ }
+
+ // Make the necessary directories.
+ if err := os.MkdirAll(path.Dir(filename), 0755); err != nil {
+ return fmt.Errorf("error creating config directory for %q: %v", filename, err)
+ }
+
+ // Write the new configuration.
+ if err := ioutil.WriteFile(filename, b, 0644); err != nil {
+ return fmt.Errorf("error writing config file %q: %v", filename, err)
+ }
+
+ return nil
+}
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index de8202bb1..26d1cd5ab 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -47,6 +47,7 @@ go_test(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/kernel",
@@ -56,7 +57,7 @@ go_test(
"//runsc/boot",
"//runsc/boot/platforms",
"//runsc/specutils",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_kr_pty//:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index e9372989f..7d67c3a75 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -30,7 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/unet"
"gvisor.dev/gvisor/pkg/urpc"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// socketPath creates a path inside bundleDir and ensures that the returned
diff --git a/runsc/container/container.go b/runsc/container/container.go
index bbb364214..a721c1c31 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -513,9 +513,16 @@ func (c *Container) Start(conf *boot.Config) error {
return err
}
- // Adjust the oom_score_adj for sandbox and gofers. This must be done after
+ // Adjust the oom_score_adj for sandbox. This must be done after
// save().
- return c.adjustOOMScoreAdj(conf)
+ err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false)
+ if err != nil {
+ return err
+ }
+
+ // Set container's oom_score_adj to the gofer since it is dedicated to
+ // the container, in case the gofer uses up too much memory.
+ return c.adjustGoferOOMScoreAdj()
}
// Restore takes a container and replaces its kernel and file system
@@ -782,6 +789,9 @@ func (c *Container) Destroy() error {
}
defer unlock()
+ // Stored for later use as stop() sets c.Sandbox to nil.
+ sb := c.Sandbox
+
if err := c.stop(); err != nil {
err = fmt.Errorf("stopping container: %v", err)
log.Warningf("%v", err)
@@ -796,6 +806,16 @@ func (c *Container) Destroy() error {
c.changeStatus(Stopped)
+ // Adjust oom_score_adj for the sandbox. This must be done after the
+ // container is stopped and the directory at c.Root is removed.
+ // We must test if the sandbox is nil because Destroy should be
+ // idempotent.
+ if sb != nil {
+ if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil {
+ errs = append(errs, err.Error())
+ }
+ }
+
// "If any poststop hook fails, the runtime MUST log a warning, but the
// remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec.
// Based on the OCI, "The post-stop hooks MUST be called after the container is
@@ -926,7 +946,14 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund
}
if conf.DebugLog != "" {
- debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer")
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer", test)
if err != nil {
return nil, nil, fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
}
@@ -1139,35 +1166,82 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
return fn()
}
-// adjustOOMScoreAdj sets the oom_score_adj for the sandbox and all gofers.
+// adjustGoferOOMScoreAdj sets the oom_store_adj for the container's gofer.
+func (c *Container) adjustGoferOOMScoreAdj() error {
+ if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil {
+ if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
+ return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
+ }
+ }
+
+ return nil
+}
+
+// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox.
// oom_score_adj is set to the lowest oom_score_adj among the containers
// running in the sandbox.
//
// TODO(gvisor.dev/issue/512): This call could race with other containers being
// created at the same time and end up setting the wrong oom_score_adj to the
// sandbox.
-func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error {
- // If this container's OOMScoreAdj is nil then we can exit early as no
- // change should be made to oom_score_adj for the sandbox.
- if c.Spec.Process.OOMScoreAdj == nil {
- return nil
- }
-
- containers, err := loadSandbox(conf.RootDir, c.Sandbox.ID)
+func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
+ containers, err := loadSandbox(rootDir, s.ID)
if err != nil {
return fmt.Errorf("loading sandbox containers: %v", err)
}
+ // Do nothing if the sandbox has been terminated.
+ if len(containers) == 0 {
+ return nil
+ }
+
// Get the lowest score for all containers.
var lowScore int
scoreFound := false
- for _, container := range containers {
- if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ if len(containers) == 1 && len(containers[0].Spec.Annotations[specutils.ContainerdContainerTypeAnnotation]) == 0 {
+ // This is a single-container sandbox. Set the oom_score_adj to
+ // the value specified in the OCI bundle.
+ if containers[0].Spec.Process.OOMScoreAdj != nil {
scoreFound = true
- lowScore = *container.Spec.Process.OOMScoreAdj
+ lowScore = *containers[0].Spec.Process.OOMScoreAdj
+ }
+ } else {
+ for _, container := range containers {
+ // Special multi-container support for CRI. Ignore the root
+ // container when calculating oom_score_adj for the sandbox because
+ // it is the infrastructure (pause) container and always has a very
+ // low oom_score_adj.
+ //
+ // We will use OOMScoreAdj in the single-container case where the
+ // containerd container-type annotation is not present.
+ if container.Spec.Annotations[specutils.ContainerdContainerTypeAnnotation] == specutils.ContainerdContainerTypeSandbox {
+ continue
+ }
+
+ if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ scoreFound = true
+ lowScore = *container.Spec.Process.OOMScoreAdj
+ }
}
}
+ // If the container is destroyed and remaining containers have no
+ // oomScoreAdj specified then we must revert to the oom_score_adj of the
+ // parent process.
+ if !scoreFound && destroy {
+ ppid, err := specutils.GetParentPid(s.Pid)
+ if err != nil {
+ return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err)
+ }
+ pScore, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err)
+ }
+
+ scoreFound = true
+ lowScore = pScore
+ }
+
// Only set oom_score_adj if one of the containers has oom_score_adj set
// in the OCI bundle. If not, we need to inherit the parent process's
// oom_score_adj.
@@ -1177,15 +1251,10 @@ func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error {
}
// Set the lowest of all containers oom_score_adj to the sandbox.
- if err := setOOMScoreAdj(c.Sandbox.Pid, lowScore); err != nil {
- return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", c.Sandbox.ID, err)
+ if err := setOOMScoreAdj(s.Pid, lowScore); err != nil {
+ return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
}
- // Set container's oom_score_adj to the gofer since it is dedicated to the
- // container, in case the gofer uses up too much memory.
- if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
- return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
- }
return nil
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index af128bf1c..519f5ed9b 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -16,6 +16,7 @@ package container
import (
"bytes"
+ "flag"
"fmt"
"io"
"io/ioutil"
@@ -33,13 +34,14 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/boot/platforms"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// waitForProcessList waits for the given process list to show up in the container.
@@ -155,12 +157,7 @@ func waitForFile(f *os.File) error {
return nil
}
- timeout := 5 * time.Second
- if testutil.RaceEnabled {
- // Race makes slow things even slow, so bump the timeout.
- timeout = 3 * timeout
- }
- return testutil.Poll(op, timeout)
+ return testutil.Poll(op, 30*time.Second)
}
// readOutputNum reads a file at given filepath and returns the int at the
@@ -254,10 +251,6 @@ func configs(opts ...configOption) []*boot.Config {
// TODO(b/112165693): KVM tests are flaky. Disable until fixed.
continue
- // TODO(b/68787993): KVM doesn't work with --race.
- if testutil.RaceEnabled {
- continue
- }
c.Platform = platforms.KVM
case nonExclusiveFS:
c.FileAccess = boot.FileAccessShared
@@ -1310,10 +1303,13 @@ func TestRunNonRoot(t *testing.T) {
t.Logf("Running test with conf: %+v", conf)
spec := testutil.NewSpecWithArgs("/bin/true")
+
+ // Set a random user/group with no access to "blocked" dir.
spec.Process.User.UID = 343
spec.Process.User.GID = 2401
+ spec.Process.Capabilities = nil
- // User that container runs as can't list '$TMP/blocked' and would fail to
+ // User running inside container can't list '$TMP/blocked' and would fail to
// mount it.
dir, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
if err != nil {
@@ -1327,6 +1323,17 @@ func TestRunNonRoot(t *testing.T) {
t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
}
+ src, err := ioutil.TempDir(testutil.TmpDir(), "src")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+
+ spec.Mounts = append(spec.Mounts, specs.Mount{
+ Destination: dir,
+ Source: src,
+ Type: "bind",
+ })
+
if err := run(spec, conf); err != nil {
t.Fatalf("error running sandbox: %v", err)
}
@@ -1637,22 +1644,27 @@ func TestGoferExits(t *testing.T) {
}
func TestRootNotMount(t *testing.T) {
- if testutil.RaceEnabled {
- // Requires statically linked binary, since it's mapping the root to a
- // random dir, libs cannot be located.
- t.Skip("race makes test_app not statically linked")
- }
-
appSym, err := testutil.FindFile("runsc/container/test_app/test_app")
if err != nil {
t.Fatal("error finding test_app:", err)
}
+
app, err := filepath.EvalSymlinks(appSym)
if err != nil {
t.Fatalf("error resolving %q symlink: %v", appSym, err)
}
log.Infof("App path %q is a symlink to %q", appSym, app)
+ static, err := testutil.IsStatic(app)
+ if err != nil {
+ t.Fatalf("error reading application binary: %v", err)
+ }
+ if !static {
+ // This happens during race builds; we cannot map in shared
+ // libraries also, so we need to skip the test.
+ t.Skip()
+ }
+
root := filepath.Dir(app)
exe := "/" + filepath.Base(app)
log.Infof("Executing %q in %q", exe, root)
@@ -2038,6 +2050,30 @@ func TestMountSymlink(t *testing.T) {
}
}
+// Check that --net-raw disables the CAP_NET_RAW capability.
+func TestNetRaw(t *testing.T) {
+ capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10)
+ app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ for _, enableRaw := range []bool{true, false} {
+ conf := testutil.TestConfig()
+ conf.EnableRaw = enableRaw
+
+ test := "--enabled"
+ if !enableRaw {
+ test = "--disabled"
+ }
+
+ spec := testutil.NewSpecWithArgs(app, "capability", test, capNetRaw)
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+ }
+}
+
// executeSync synchronously executes a new process.
func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
pid, err := cont.Execute(args)
@@ -2053,10 +2089,10 @@ func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus,
func TestMain(m *testing.M) {
log.SetLevel(log.Debug)
+ flag.Parse()
if err := testutil.ConfigureExePath(); err != nil {
panic(err.Error())
}
specutils.MaybeRunAsRoot()
-
os.Exit(m.Run())
}
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index 2d51fecc6..bd45a5118 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -32,7 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
@@ -549,10 +549,16 @@ func TestMultiContainerDestroy(t *testing.T) {
t.Logf("Running test with conf: %+v", conf)
// First container will remain intact while the second container is killed.
- specs, ids := createSpecs(
- []string{app, "reaper"},
+ podSpecs, ids := createSpecs(
+ []string{"sleep", "100"},
[]string{app, "fork-bomb"})
- containers, cleanup, err := startContainers(conf, specs, ids)
+
+ // Run the fork bomb in a PID namespace to prevent processes to be
+ // re-parented to PID=1 in the root container.
+ podSpecs[1].Linux = &specs.Linux{
+ Namespaces: []specs.LinuxNamespace{{Type: "pid"}},
+ }
+ containers, cleanup, err := startContainers(conf, podSpecs, ids)
if err != nil {
t.Fatalf("error starting containers: %v", err)
}
@@ -580,7 +586,7 @@ func TestMultiContainerDestroy(t *testing.T) {
if err != nil {
t.Fatalf("error getting process data from sandbox: %v", err)
}
- expectedPL := []*control.Process{{PID: 1, Cmd: "test_app"}}
+ expectedPL := []*control.Process{{PID: 1, Cmd: "sleep"}}
if !procListsEqual(pss, expectedPL) {
t.Errorf("container got process list: %s, want: %s", procListToString(pss), procListToString(expectedPL))
}
@@ -1485,3 +1491,58 @@ func TestMultiContainerLoadSandbox(t *testing.T) {
t.Errorf("containers not found: %v", wantIDs)
}
}
+
+// TestMultiContainerRunNonRoot checks that child container can be configured
+// when running as non-privileged user.
+func TestMultiContainerRunNonRoot(t *testing.T) {
+ cmdRoot := []string{"/bin/sleep", "100"}
+ cmdSub := []string{"/bin/true"}
+ podSpecs, ids := createSpecs(cmdRoot, cmdSub)
+
+ // User running inside container can't list '$TMP/blocked' and would fail to
+ // mount it.
+ blocked, err := ioutil.TempDir(testutil.TmpDir(), "blocked")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+ if err := os.Chmod(blocked, 0700); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", blocked, err)
+ }
+ dir := path.Join(blocked, "test")
+ if err := os.Mkdir(dir, 0755); err != nil {
+ t.Fatalf("os.MkDir(%q) failed: %v", dir, err)
+ }
+
+ src, err := ioutil.TempDir(testutil.TmpDir(), "src")
+ if err != nil {
+ t.Fatalf("ioutil.TempDir() failed: %v", err)
+ }
+
+ // Set a random user/group with no access to "blocked" dir.
+ podSpecs[1].Process.User.UID = 343
+ podSpecs[1].Process.User.GID = 2401
+ podSpecs[1].Process.Capabilities = nil
+
+ podSpecs[1].Mounts = append(podSpecs[1].Mounts, specs.Mount{
+ Destination: dir,
+ Source: src,
+ Type: "bind",
+ })
+
+ conf := testutil.TestConfig()
+ pod, cleanup, err := startContainers(conf, podSpecs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ // Once all containers are started, wait for the child container to exit.
+ // This means that the volume was mounted properly.
+ ws, err := pod[1].Wait()
+ if err != nil {
+ t.Fatalf("running child container: %v", err)
+ }
+ if !ws.Exited() || ws.ExitStatus() != 0 {
+ t.Fatalf("child container failed, waitStatus: %v", ws)
+ }
+}
diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go
index 1f90d2462..dc4194134 100644
--- a/runsc/container/shared_volume_test.go
+++ b/runsc/container/shared_volume_test.go
@@ -25,7 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/runsc/boot"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// TestSharedVolume checks that modifications to a volume mount are propagated
diff --git a/runsc/container/test_app/BUILD b/runsc/container/test_app/BUILD
index 82dbd54d2..9bf9e6e9d 100644
--- a/runsc/container/test_app/BUILD
+++ b/runsc/container/test_app/BUILD
@@ -13,7 +13,7 @@ go_binary(
visibility = ["//runsc/container:__pkg__"],
deps = [
"//pkg/unet",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"@com_github_google_subcommands//:go_default_library",
],
)
diff --git a/runsc/container/test_app/fds.go b/runsc/container/test_app/fds.go
index c12809cab..a90cc1662 100644
--- a/runsc/container/test_app/fds.go
+++ b/runsc/container/test_app/fds.go
@@ -24,7 +24,7 @@ import (
"flag"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/unet"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
const fileContents = "foobarbaz"
@@ -60,7 +60,7 @@ func (fds *fdSender) Execute(ctx context.Context, f *flag.FlagSet, args ...inter
log.Fatalf("socket flag must be set")
}
- dir, err := ioutil.TempDir(testutil.TmpDir(), "")
+ dir, err := ioutil.TempDir("", "")
if err != nil {
log.Fatalf("TempDir failed: %v", err)
}
diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go
index 6578c7b41..913d781c6 100644
--- a/runsc/container/test_app/test_app.go
+++ b/runsc/container/test_app/test_app.go
@@ -19,22 +19,25 @@ package main
import (
"context"
"fmt"
+ "io/ioutil"
"log"
"net"
"os"
"os/exec"
+ "regexp"
"strconv"
sys "syscall"
"time"
"flag"
"github.com/google/subcommands"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func main() {
subcommands.Register(subcommands.HelpCommand(), "")
subcommands.Register(subcommands.FlagsCommand(), "")
+ subcommands.Register(new(capability), "")
subcommands.Register(new(fdReceiver), "")
subcommands.Register(new(fdSender), "")
subcommands.Register(new(forkBomb), "")
@@ -287,3 +290,65 @@ func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interfac
}
return subcommands.ExitSuccess
}
+
+type capability struct {
+ enabled uint64
+ disabled uint64
+}
+
+// Name implements subcommands.Command.
+func (*capability) Name() string {
+ return "capability"
+}
+
+// Synopsis implements subcommands.Command.
+func (*capability) Synopsis() string {
+ return "checks if effective capabilities are set/unset"
+}
+
+// Usage implements subcommands.Command.
+func (*capability) Usage() string {
+ return "capability [--enabled=number] [--disabled=number]"
+}
+
+// SetFlags implements subcommands.Command.
+func (c *capability) SetFlags(f *flag.FlagSet) {
+ f.Uint64Var(&c.enabled, "enabled", 0, "")
+ f.Uint64Var(&c.disabled, "disabled", 0, "")
+}
+
+// Execute implements subcommands.Command.
+func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if c.enabled == 0 && c.disabled == 0 {
+ fmt.Println("One of the flags must be set")
+ return subcommands.ExitUsageError
+ }
+
+ status, err := ioutil.ReadFile("/proc/self/status")
+ if err != nil {
+ fmt.Printf("Error reading %q: %v\n", "proc/self/status", err)
+ return subcommands.ExitFailure
+ }
+ re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n")
+ matches := re.FindStringSubmatch(string(status))
+ if matches == nil || len(matches) != 2 {
+ fmt.Printf("Effective capabilities not found in\n%s\n", status)
+ return subcommands.ExitFailure
+ }
+ caps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err)
+ return subcommands.ExitFailure
+ }
+
+ if c.enabled != 0 && (caps&c.enabled) != c.enabled {
+ fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps)
+ return subcommands.ExitFailure
+ }
+ if c.disabled != 0 && (caps&c.disabled) != 0 {
+ fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps)
+ return subcommands.ExitFailure
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/criutil/BUILD b/runsc/criutil/BUILD
new file mode 100644
index 000000000..558133a0e
--- /dev/null
+++ b/runsc/criutil/BUILD
@@ -0,0 +1,12 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "criutil",
+ testonly = 1,
+ srcs = ["criutil.go"],
+ importpath = "gvisor.dev/gvisor/runsc/criutil",
+ visibility = ["//:sandbox"],
+ deps = ["//runsc/testutil"],
+)
diff --git a/runsc/test/testutil/crictl.go b/runsc/criutil/criutil.go
index 4f9ee0c05..c8ddf5a9a 100644
--- a/runsc/test/testutil/crictl.go
+++ b/runsc/criutil/criutil.go
@@ -12,7 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testutil
+// Package criutil contains utility functions for interacting with the
+// Container Runtime Interface (CRI), principally via the crictl command line
+// tool. This requires critools to be installed on the local system.
+package criutil
import (
"encoding/json"
@@ -21,6 +24,8 @@ import (
"os/exec"
"strings"
"time"
+
+ "gvisor.dev/gvisor/runsc/testutil"
)
const endpointPrefix = "unix://"
@@ -160,11 +165,11 @@ func (cc *Crictl) StartPodAndContainer(image, sbSpec, contSpec string) (string,
}
// Write the specs to files that can be read by crictl.
- sbSpecFile, err := WriteTmpFile("sbSpec", sbSpec)
+ sbSpecFile, err := testutil.WriteTmpFile("sbSpec", sbSpec)
if err != nil {
return "", "", fmt.Errorf("failed to write sandbox spec: %v", err)
}
- contSpecFile, err := WriteTmpFile("contSpec", contSpec)
+ contSpecFile, err := testutil.WriteTmpFile("contSpec", contSpec)
if err != nil {
return "", "", fmt.Errorf("failed to write container spec: %v", err)
}
@@ -233,7 +238,7 @@ func (cc *Crictl) run(args ...string) (string, error) {
case err := <-errCh:
return "", err
case <-time.After(cc.timeout):
- if err := KillCommand(cmd); err != nil {
+ if err := testutil.KillCommand(cmd); err != nil {
return "", fmt.Errorf("timed out, then couldn't kill process %+v: %v", cmd, err)
}
return "", fmt.Errorf("timed out: %+v", cmd)
diff --git a/runsc/debian/postinst.sh b/runsc/debian/postinst.sh
index 03a5ff524..dc7aeee87 100755
--- a/runsc/debian/postinst.sh
+++ b/runsc/debian/postinst.sh
@@ -15,10 +15,10 @@
# limitations under the License.
if [ "$1" != configure ]; then
- exit 0
+ exit 0
fi
if [ -f /etc/docker/daemon.json ]; then
- /usr/libexec/runsc/dockercfg runtime-add runsc /usr/bin/runsc
- systemctl restart docker
+ runsc install
+ systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2
fi
diff --git a/runsc/dockerutil/BUILD b/runsc/dockerutil/BUILD
new file mode 100644
index 000000000..0e0423504
--- /dev/null
+++ b/runsc/dockerutil/BUILD
@@ -0,0 +1,15 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "dockerutil",
+ testonly = 1,
+ srcs = ["dockerutil.go"],
+ importpath = "gvisor.dev/gvisor/runsc/dockerutil",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//runsc/testutil",
+ "@com_github_kr_pty//:go_default_library",
+ ],
+)
diff --git a/runsc/test/testutil/docker.go b/runsc/dockerutil/dockerutil.go
index 3f3e191b0..57f6ae8de 100644
--- a/runsc/test/testutil/docker.go
+++ b/runsc/dockerutil/dockerutil.go
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package testutil
+// Package dockerutil is a collection of utility functions, primarily for
+// testing.
+package dockerutil
import (
+ "encoding/json"
"flag"
"fmt"
"io/ioutil"
@@ -29,26 +32,13 @@ import (
"time"
"github.com/kr/pty"
+ "gvisor.dev/gvisor/runsc/testutil"
)
-var runtimeType = flag.String("runtime-type", "", "specify which runtime to use: kvm, hostnet, overlay")
-
-func getRuntime() string {
- r, ok := os.LookupEnv("RUNSC_RUNTIME")
- if !ok {
- r = "runsc-test"
- }
- if *runtimeType != "" {
- r += "-" + *runtimeType
- }
- return r
-}
-
-// IsPauseResumeSupported returns true if Pause/Resume is supported by runtime.
-func IsPauseResumeSupported() bool {
- // Native host network stack can't be saved.
- return !strings.Contains(getRuntime(), "hostnet")
-}
+var (
+ runtime = flag.String("runtime", "runsc", "specify which runtime to use")
+ config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+)
// EnsureSupportedDockerVersion checks if correct docker is installed.
func EnsureSupportedDockerVersion() {
@@ -69,6 +59,48 @@ func EnsureSupportedDockerVersion() {
}
}
+// RuntimePath returns the binary path for the current runtime.
+func RuntimePath() (string, error) {
+ // Read the configuration data; the file must exist.
+ configBytes, err := ioutil.ReadFile(*config)
+ if err != nil {
+ return "", err
+ }
+
+ // Unmarshal the configuration.
+ c := make(map[string]interface{})
+ if err := json.Unmarshal(configBytes, &c); err != nil {
+ return "", err
+ }
+
+ // Decode the expected configuration.
+ r, ok := c["runtimes"]
+ if !ok {
+ return "", fmt.Errorf("no runtimes declared: %v", c)
+ }
+ rs, ok := r.(map[string]interface{})
+ if !ok {
+ // The runtimes are not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ r, ok = rs[*runtime]
+ if !ok {
+ // The expected runtime is not declared.
+ return "", fmt.Errorf("runtime %q not found: %v", *runtime, c)
+ }
+ rs, ok = r.(map[string]interface{})
+ if !ok {
+ // The runtime is not a map.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ p, ok := rs["path"].(string)
+ if !ok {
+ // The runtime does not declare a path.
+ return "", fmt.Errorf("unexpected format: %v", c)
+ }
+ return p, nil
+}
+
// MountMode describes if the mount should be ro or rw.
type MountMode int
@@ -113,7 +145,7 @@ func PrepareFiles(names ...string) (string, error) {
for _, name := range names {
src := getLocalPath(name)
dst := path.Join(dir, name)
- if err := Copy(src, dst); err != nil {
+ if err := testutil.Copy(src, dst); err != nil {
return "", fmt.Errorf("testutil.Copy(%q, %q) failed: %v", src, dst, err)
}
}
@@ -163,7 +195,10 @@ type Docker struct {
// MakeDocker sets up the struct for a Docker container.
// Names of containers will be unique.
func MakeDocker(namePrefix string) Docker {
- return Docker{Name: RandomName(namePrefix), Runtime: getRuntime()}
+ return Docker{
+ Name: testutil.RandomName(namePrefix),
+ Runtime: *runtime,
+ }
}
// logDockerID logs a container id, which is needed to find container runsc logs.
@@ -205,7 +240,7 @@ func (d *Docker) Stop() error {
// Run calls 'docker run' with the arguments provided. The container starts
// running in the background and the call returns immediately.
func (d *Docker) Run(args ...string) error {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-d"}
+ a := d.runArgs("-d")
a = append(a, args...)
_, err := do(a...)
if err == nil {
@@ -216,7 +251,7 @@ func (d *Docker) Run(args ...string) error {
// RunWithPty is like Run but with an attached pty.
func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-it"}
+ a := d.runArgs("-it")
a = append(a, args...)
return doWithPty(a...)
}
@@ -224,8 +259,7 @@ func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) {
// RunFg calls 'docker run' with the arguments provided in the foreground. It
// blocks until the container exits and returns the output.
func (d *Docker) RunFg(args ...string) (string, error) {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name}
- a = append(a, args...)
+ a := d.runArgs(args...)
out, err := do(a...)
if err == nil {
d.logDockerID()
@@ -233,6 +267,14 @@ func (d *Docker) RunFg(args ...string) (string, error) {
return string(out), err
}
+func (d *Docker) runArgs(args ...string) []string {
+ // Environment variable RUNSC_TEST_NAME is picked up by the runtime and added
+ // to the log name, so one can easily identify the corresponding logs for
+ // this test.
+ rv := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-e", "RUNSC_TEST_NAME=" + d.Name}
+ return append(rv, args...)
+}
+
// Logs calls 'docker logs'.
func (d *Docker) Logs() (string, error) {
return do("logs", d.Name)
@@ -240,7 +282,22 @@ func (d *Docker) Logs() (string, error) {
// Exec calls 'docker exec' with the arguments provided.
func (d *Docker) Exec(args ...string) (string, error) {
- a := []string{"exec", d.Name}
+ return d.ExecWithFlags(nil, args...)
+}
+
+// ExecWithFlags calls 'docker exec <flags> name <args>'.
+func (d *Docker) ExecWithFlags(flags []string, args ...string) (string, error) {
+ a := []string{"exec"}
+ a = append(a, flags...)
+ a = append(a, d.Name)
+ a = append(a, args...)
+ return do(a...)
+}
+
+// ExecAsUser calls 'docker exec' as the given user with the arguments
+// provided.
+func (d *Docker) ExecAsUser(user string, args ...string) (string, error) {
+ a := []string{"exec", "--user", user, d.Name}
a = append(a, args...)
return do(a...)
}
@@ -297,7 +354,11 @@ func (d *Docker) Remove() error {
func (d *Docker) CleanUp() {
d.logDockerID()
if _, err := do("kill", d.Name); err != nil {
- log.Printf("error killing container %q: %v", d.Name, err)
+ if strings.Contains(err.Error(), "is not running") {
+ // Nothing to kill. Don't log the error in this case.
+ } else {
+ log.Printf("error killing container %q: %v", d.Name, err)
+ }
}
if err := d.Remove(); err != nil {
log.Print(err)
diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD
index e2318a978..02168ad1b 100644
--- a/runsc/fsgofer/filter/BUILD
+++ b/runsc/fsgofer/filter/BUILD
@@ -17,6 +17,7 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/seccomp",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 8ddfa77d6..c7922b54f 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -83,6 +83,11 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowAny{},
seccomp.AllowValue(syscall.F_GETFD),
},
+ // Used by flipcall.PacketWindowAllocator.Init().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(unix.F_ADD_SEALS),
+ },
},
syscall.SYS_FSTAT: {},
syscall.SYS_FSTATFS: {},
@@ -103,6 +108,19 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowAny{},
seccomp.AllowValue(0),
},
+ // Non-private futex used for flipcall.
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
},
syscall.SYS_GETDENTS64: {},
syscall.SYS_GETPID: {},
@@ -112,6 +130,7 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_LINKAT: {},
syscall.SYS_LSEEK: {},
syscall.SYS_MADVISE: {},
+ unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init().
syscall.SYS_MKDIRAT: {},
syscall.SYS_MMAP: []seccomp.Rule{
{
@@ -160,6 +179,13 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_RT_SIGPROCMASK: {},
syscall.SYS_SCHED_YIELD: {},
syscall.SYS_SENDMSG: []seccomp.Rule{
+ // Used by fdchannel.Endpoint.SendFD().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ // Used by unet.SocketWriter.WriteVec().
{
seccomp.AllowAny{},
seccomp.AllowAny{},
@@ -170,7 +196,15 @@ var allowedSyscalls = seccomp.SyscallRules{
{seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
- syscall.SYS_SYMLINKAT: {},
+ // Used by fdchannel.NewConnectedSockets().
+ syscall.SYS_SOCKETPAIR: {
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_SYMLINKAT: {},
syscall.SYS_TGKILL: []seccomp.Rule{
{
seccomp.AllowValue(uint64(os.Getpid())),
@@ -180,3 +214,16 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_UTIMENSAT: {},
syscall.SYS_WRITE: {},
}
+
+var udsSyscalls = seccomp.SyscallRules{
+ syscall.SYS_SOCKET: []seccomp.Rule{
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ },
+ },
+ syscall.SYS_CONNECT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ },
+ },
+}
diff --git a/runsc/fsgofer/filter/filter.go b/runsc/fsgofer/filter/filter.go
index 65053415f..289886720 100644
--- a/runsc/fsgofer/filter/filter.go
+++ b/runsc/fsgofer/filter/filter.go
@@ -23,11 +23,16 @@ import (
// Install installs seccomp filters.
func Install() error {
- s := allowedSyscalls
-
// Set of additional filters used by -race and -msan. Returns empty
// when not enabled.
- s.Merge(instrumentationFilters())
+ allowedSyscalls.Merge(instrumentationFilters())
+
+ return seccomp.Install(allowedSyscalls)
+}
- return seccomp.Install(s)
+// InstallUDSFilters extends the allowed syscalls to include those necessary for
+// connecting to a host UDS.
+func InstallUDSFilters() {
+ // Add additional filters required for connecting to the host's sockets.
+ allowedSyscalls.Merge(udsSyscalls)
}
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 7c4d2b94e..29a82138e 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -21,6 +21,7 @@
package fsgofer
import (
+ "errors"
"fmt"
"io"
"math"
@@ -54,6 +55,7 @@ const (
regular fileType = iota
directory
symlink
+ socket
unknown
)
@@ -66,6 +68,8 @@ func (f fileType) String() string {
return "directory"
case symlink:
return "symlink"
+ case socket:
+ return "socket"
}
return "unknown"
}
@@ -82,6 +86,9 @@ type Config struct {
// PanicOnWrite panics on attempts to write to RO mounts.
PanicOnWrite bool
+
+ // HostUDS signals whether the gofer can mount a host's UDS.
+ HostUDS bool
}
type attachPoint struct {
@@ -124,24 +131,50 @@ func (a *attachPoint) Attach() (p9.File, error) {
if err != nil {
return nil, fmt.Errorf("stat file %q, err: %v", a.prefix, err)
}
- mode := syscall.O_RDWR
- if a.conf.ROMount || (stat.Mode&syscall.S_IFMT) == syscall.S_IFDIR {
- mode = syscall.O_RDONLY
- }
-
- // Open the root directory.
- f, err := fd.Open(a.prefix, openFlags|mode, 0)
- if err != nil {
- return nil, fmt.Errorf("unable to open file %q, err: %v", a.prefix, err)
- }
+ // Acquire the attach point lock.
a.attachedMu.Lock()
defer a.attachedMu.Unlock()
+
if a.attached {
- f.Close()
return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
}
+ // Hold the file descriptor we are converting into a p9.File.
+ var f *fd.FD
+
+ // Apply the S_IFMT bitmask so we can detect file type appropriately.
+ switch fmtStat := stat.Mode & syscall.S_IFMT; fmtStat {
+ case syscall.S_IFSOCK:
+ // Check to see if the CLI option has been set to allow the UDS mount.
+ if !a.conf.HostUDS {
+ return nil, errors.New("host UDS support is disabled")
+ }
+
+ // Attempt to open a connection. Bubble up the failures.
+ f, err = fd.DialUnix(a.prefix)
+ if err != nil {
+ return nil, err
+ }
+
+ default:
+ // Default to Read/Write permissions.
+ mode := syscall.O_RDWR
+
+ // If the configuration is Read Only or the mount point is a directory,
+ // set the mode to Read Only.
+ if a.conf.ROMount || fmtStat == syscall.S_IFDIR {
+ mode = syscall.O_RDONLY
+ }
+
+ // Open the mount point & capture the FD.
+ f, err = fd.Open(a.prefix, openFlags|mode, 0)
+ if err != nil {
+ return nil, fmt.Errorf("unable to open file %q, err: %v", a.prefix, err)
+ }
+ }
+
+ // Return a localFile object to the caller with the UDS FD included.
rv, err := newLocalFile(a, f, a.prefix, stat)
if err != nil {
return nil, err
@@ -295,7 +328,7 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error)
return file, nil
}
-func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
+func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) {
var ft fileType
switch stat.Mode & syscall.S_IFMT {
case syscall.S_IFREG:
@@ -304,6 +337,11 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
ft = directory
case syscall.S_IFLNK:
ft = symlink
+ case syscall.S_IFSOCK:
+ if !permitSocket {
+ return unknown, syscall.EPERM
+ }
+ ft = socket
default:
return unknown, syscall.EPERM
}
@@ -311,7 +349,7 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
}
func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) {
- ft, err := getSupportedFileType(stat)
+ ft, err := getSupportedFileType(stat, a.conf.HostUDS)
if err != nil {
return nil, err
}
@@ -1026,7 +1064,11 @@ func (l *localFile) Flush() error {
// Connect implements p9.File.
func (l *localFile) Connect(p9.ConnectFlags) (*fd.FD, error) {
- return nil, syscall.ECONNREFUSED
+ // Check to see if the CLI option has been set to allow the UDS mount.
+ if !l.attachPoint.conf.HostUDS {
+ return nil, syscall.ECONNREFUSED
+ }
+ return fd.DialUnix(l.hostPath)
}
// Close implements p9.File.
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index c86beaef1..05af7e397 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -635,7 +635,15 @@ func TestAttachInvalidType(t *testing.T) {
t.Fatalf("Mkfifo(%q): %v", fifo, err)
}
- socket := filepath.Join(dir, "socket")
+ dirFile, err := os.Open(dir)
+ if err != nil {
+ t.Fatalf("Open(%s): %v", dir, err)
+ }
+ defer dirFile.Close()
+
+ // Bind a socket via /proc to be sure that a length of a socket path
+ // is less than UNIX_PATH_MAX.
+ socket := filepath.Join(fmt.Sprintf("/proc/self/fd/%d", dirFile.Fd()), "socket")
l, err := net.Listen("unix", socket)
if err != nil {
t.Fatalf("net.Listen(unix, %q): %v", socket, err)
@@ -657,7 +665,7 @@ func TestAttachInvalidType(t *testing.T) {
}
f, err := a.Attach()
if f != nil || err == nil {
- t.Fatalf("Attach should have failed, got (%v, nil)", f)
+ t.Fatalf("Attach should have failed, got (%v, %v)", f, err)
}
})
}
diff --git a/runsc/main.go b/runsc/main.go
index c61583441..7dce9dc00 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -31,6 +31,7 @@ import (
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/runsc/boot"
"gvisor.dev/gvisor/runsc/cmd"
@@ -67,6 +68,7 @@ var (
network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
gso = flag.Bool("gso", true, "enable generic segmenation offload")
fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
+ fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "Allow the gofer to mount Unix Domain Sockets.")
overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.")
panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.")
@@ -74,9 +76,11 @@ var (
netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.")
numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.")
rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.")
+ referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.")
// Test flags, not to be used outside tests, ever.
testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
+ testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.")
)
func main() {
@@ -86,6 +90,11 @@ func main() {
subcommands.Register(help, "")
subcommands.Register(subcommands.FlagsCommand(), "")
+ // Installation helpers.
+ const helperGroup = "helpers"
+ subcommands.Register(new(cmd.Install), helperGroup)
+ subcommands.Register(new(cmd.Uninstall), helperGroup)
+
// Register user-facing runsc commands.
subcommands.Register(new(cmd.Checkpoint), "")
subcommands.Register(new(cmd.Create), "")
@@ -169,6 +178,15 @@ func main() {
cmd.Fatalf("num_network_channels must be > 0, got: %d", *numNetworkChannels)
}
+ refsLeakMode, err := boot.MakeRefsLeakMode(*referenceLeakMode)
+ if err != nil {
+ cmd.Fatalf("%v", err)
+ }
+
+ // Sets the reference leak check mode. Also set it in config below to
+ // propagate it to child processes.
+ refs.SetLeakMode(refsLeakMode)
+
// Create a new Config from the flags.
conf := &boot.Config{
RootDir: *rootDir,
@@ -178,6 +196,7 @@ func main() {
DebugLog: *debugLog,
DebugLogFormat: *debugLogFormat,
FileAccess: fsAccess,
+ FSGoferHostUDS: *fsGoferHostUDS,
Overlay: *overlay,
Network: netType,
GSO: *gso,
@@ -192,8 +211,10 @@ func main() {
NumNetworkChannels: *numNetworkChannels,
Rootless: *rootless,
AlsoLogToStderr: *alsoLogToStderr,
+ ReferenceLeakMode: refsLeakMode,
TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
+ TestOnlyTestNameEnv: *testOnlyTestNameEnv,
}
if len(*straceSyscalls) != 0 {
conf.StraceSyscalls = strings.Split(*straceSyscalls, ",")
@@ -220,14 +241,14 @@ func main() {
// want with them. Since Docker and Containerd both eat boot's stderr, we
// dup our stderr to the provided log FD so that panics will appear in the
// logs, rather than just disappear.
- if err := syscall.Dup2(int(f.Fd()), int(os.Stderr.Fd())); err != nil {
+ if err := syscall.Dup3(int(f.Fd()), int(os.Stderr.Fd()), 0); err != nil {
cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err)
}
e = newEmitter(*debugLogFormat, f)
} else if *debugLog != "" {
- f, err := specutils.DebugLogFile(*debugLog, subcommand)
+ f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */)
if err != nil {
cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err)
}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index df3c0c5ef..ee9327fc8 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -351,7 +351,15 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
if conf.DebugLog != "" {
- debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot")
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot", test)
if err != nil {
return fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
}
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
index fbfb8e2f8..fa58313a0 100644
--- a/runsc/specutils/BUILD
+++ b/runsc/specutils/BUILD
@@ -13,6 +13,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/kernel/auth",
"@com_github_cenkalti_backoff//:go_default_library",
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
index d441419cb..c7dd3051c 100644
--- a/runsc/specutils/namespace.go
+++ b/runsc/specutils/namespace.go
@@ -33,19 +33,19 @@ import (
func nsCloneFlag(nst specs.LinuxNamespaceType) uintptr {
switch nst {
case specs.IPCNamespace:
- return syscall.CLONE_NEWIPC
+ return unix.CLONE_NEWIPC
case specs.MountNamespace:
- return syscall.CLONE_NEWNS
+ return unix.CLONE_NEWNS
case specs.NetworkNamespace:
- return syscall.CLONE_NEWNET
+ return unix.CLONE_NEWNET
case specs.PIDNamespace:
- return syscall.CLONE_NEWPID
+ return unix.CLONE_NEWPID
case specs.UTSNamespace:
- return syscall.CLONE_NEWUTS
+ return unix.CLONE_NEWUTS
case specs.UserNamespace:
- return syscall.CLONE_NEWUSER
+ return unix.CLONE_NEWUSER
case specs.CgroupNamespace:
- panic("cgroup namespace has no associated clone flag")
+ return unix.CLONE_NEWCGROUP
default:
panic(fmt.Sprintf("unknown namespace %v", nst))
}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 2eec92349..591abe458 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -23,6 +23,7 @@ import (
"os"
"path"
"path/filepath"
+ "strconv"
"strings"
"syscall"
"time"
@@ -30,6 +31,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -240,6 +242,15 @@ func AllCapabilities() *specs.LinuxCapabilities {
}
}
+// AllCapabilitiesUint64 returns a bitmask containing all capabilities set.
+func AllCapabilitiesUint64() uint64 {
+ var rv uint64
+ for _, cap := range capFromName {
+ rv |= bits.MaskOf64(int(cap))
+ }
+ return rv
+}
+
var capFromName = map[string]linux.Capability{
"CAP_CHOWN": linux.CAP_CHOWN,
"CAP_DAC_OVERRIDE": linux.CAP_DAC_OVERRIDE,
@@ -398,13 +409,15 @@ func WaitForReady(pid int, timeout time.Duration, ready func() (bool, error)) er
// - %TIMESTAMP%: is replaced with a timestamp using the following format:
// <yyyymmdd-hhmmss.uuuuuu>
// - %COMMAND%: is replaced with 'command'
-func DebugLogFile(logPattern, command string) (*os.File, error) {
+// - %TEST%: is replaced with 'test' (omitted by default)
+func DebugLogFile(logPattern, command, test string) (*os.File, error) {
if strings.HasSuffix(logPattern, "/") {
// Default format: <debug-log>/runsc.log.<yyyymmdd-hhmmss.uuuuuu>.<command>
logPattern += "runsc.log.%TIMESTAMP%.%COMMAND%"
}
logPattern = strings.Replace(logPattern, "%TIMESTAMP%", time.Now().Format("20060102-150405.000000"), -1)
logPattern = strings.Replace(logPattern, "%COMMAND%", command, -1)
+ logPattern = strings.Replace(logPattern, "%TEST%", test, -1)
dir := filepath.Dir(logPattern)
if err := os.MkdirAll(dir, 0775); err != nil {
@@ -503,3 +516,53 @@ func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) {
}
}
}
+
+// GetOOMScoreAdj reads the given process' oom_score_adj
+func GetOOMScoreAdj(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid))
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(strings.TrimSpace(string(data)))
+}
+
+// GetParentPid gets the parent process ID of the specified PID.
+func GetParentPid(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid))
+ if err != nil {
+ return 0, err
+ }
+
+ var cpid string
+ var name string
+ var state string
+ var ppid int
+ // Parse after the binary name.
+ _, err = fmt.Sscanf(string(data),
+ "%v %v %v %d",
+ // cpid is ignored.
+ &cpid,
+ // name is ignored.
+ &name,
+ // state is ignored.
+ &state,
+ &ppid)
+
+ if err != nil {
+ return 0, err
+ }
+
+ return ppid, nil
+}
+
+// EnvVar looks for a varible value in the env slice assuming the following
+// format: "NAME=VALUE".
+func EnvVar(env []string, name string) (string, bool) {
+ prefix := name + "="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.TrimPrefix(e, prefix), true
+ }
+ }
+ return "", false
+}
diff --git a/runsc/test/BUILD b/runsc/test/BUILD
deleted file mode 100644
index e69de29bb..000000000
--- a/runsc/test/BUILD
+++ /dev/null
diff --git a/runsc/test/README.md b/runsc/test/README.md
deleted file mode 100644
index f22a8e017..000000000
--- a/runsc/test/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# Tests
-
-The tests defined under this path are verifying functionality beyond what unit
-tests can cover, e.g. integration and end to end tests. Due to their nature,
-they may need extra setup in the test machine and extra configuration to run.
-
-- **integration:** defines integration tests that uses `docker run` to test
- functionality.
-- **image:** basic end to end test for popular images.
-- **root:** tests that require to be run as root.
-- **testutil:** utilities library to support the tests.
-
-The following setup steps are required in order to run these tests:
-
- `./runsc/test/install.sh [--runtime <name>]`
-
-The tests expect the runtime name to be provided in the `RUNSC_RUNTIME`
-environment variable (default: `runsc-test`). To run the tests execute:
-
-```
-bazel test --test_env=RUNSC_RUNTIME=runsc-test \
- //runsc/test/image:image_test \
- //runsc/test/integration:integration_test
-```
diff --git a/runsc/test/build_defs.bzl b/runsc/test/build_defs.bzl
deleted file mode 100644
index ac28cc037..000000000
--- a/runsc/test/build_defs.bzl
+++ /dev/null
@@ -1,19 +0,0 @@
-"""Defines a rule for runsc test targets."""
-
-load("@io_bazel_rules_go//go:def.bzl", _go_test = "go_test")
-
-# runtime_test is a macro that will create targets to run the given test target
-# with different runtime options.
-def runtime_test(**kwargs):
- """Runs the given test target with different runtime options."""
- name = kwargs["name"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_hostnet"
- kwargs["args"] = ["--runtime-type=hostnet"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_kvm"
- kwargs["args"] = ["--runtime-type=kvm"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_overlay"
- kwargs["args"] = ["--runtime-type=overlay"]
- _go_test(**kwargs)
diff --git a/runsc/test/install.sh b/runsc/test/install.sh
deleted file mode 100755
index 8f05dea20..000000000
--- a/runsc/test/install.sh
+++ /dev/null
@@ -1,93 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Fail on any error
-set -e
-
-# Defaults
-declare runtime=runsc-test
-declare uninstall=0
-
-function findExe() {
- local exe=${1}
-
- local path=$(find bazel-bin/runsc -type f -executable -name "${exe}" | head -n1)
- if [[ "${path}" == "" ]]; then
- echo "Location of ${exe} not found in bazel-bin" >&2
- exit 1
- fi
- echo "${path}"
-}
-
-while [[ $# -gt 0 ]]; do
- case "$1" in
- --runtime)
- shift
- [ "$#" -le 0 ] && echo "No runtime provided" && exit 1
- runtime=$1
- ;;
- -u)
- uninstall=1
- ;;
- *)
- echo "Unknown option: ${1}"
- echo ""
- echo "Usage: ${0} [--runtime <name>] [-u]"
- echo " --runtime sets the runtime name, default: runsc-test"
- echo " -u uninstall the runtime"
- exit 1
- esac
- shift
-done
-
-# Find location of executables.
-declare -r dockercfg=$(findExe dockercfg)
-[[ "${dockercfg}" == "" ]] && exit 1
-
-declare runsc=$(findExe runsc)
-[[ "${runsc}" == "" ]] && exit 1
-
-if [[ ${uninstall} == 0 ]]; then
- rm -rf /tmp/${runtime}
- mkdir -p /tmp/${runtime}
- cp "${runsc}" /tmp/${runtime}/runsc
- runsc=/tmp/${runtime}/runsc
-
- # Make tmp dir and runsc binary readable and executable to all users, since it
- # will run in an empty user namespace.
- chmod a+rx "${runsc}" $(dirname "${runsc}")
-
- # Make log dir executable and writable to all users for the same reason.
- declare logdir=/tmp/"${runtime?}/logs"
- mkdir -p "${logdir}"
- sudo -n chmod a+wx "${logdir}"
-
- declare -r args="--debug-log '${logdir}/' --debug --strace --log-packets"
- # experimental is needed to checkpoint/restore.
- sudo -n "${dockercfg}" --experimental=true runtime-add "${runtime}" "${runsc}" ${args}
- sudo -n "${dockercfg}" runtime-add "${runtime}"-kvm "${runsc}" --platform=kvm ${args}
- sudo -n "${dockercfg}" runtime-add "${runtime}"-hostnet "${runsc}" --network=host ${args}
- sudo -n "${dockercfg}" runtime-add "${runtime}"-overlay "${runsc}" --overlay ${args}
-
-else
- sudo -n "${dockercfg}" runtime-rm "${runtime}"
- sudo -n "${dockercfg}" runtime-rm "${runtime}"-kvm
- sudo -n "${dockercfg}" runtime-rm "${runtime}"-hostnet
- sudo -n "${dockercfg}" runtime-rm "${runtime}"-overlay
-fi
-
-echo "Restarting docker service..."
-sudo -n /etc/init.d/docker restart
diff --git a/runsc/test/integration/exec_test.go b/runsc/test/integration/exec_test.go
deleted file mode 100644
index 993136f96..000000000
--- a/runsc/test/integration/exec_test.go
+++ /dev/null
@@ -1,161 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package image provides end-to-end integration tests for runsc. These tests require
-// docker and runsc to be installed on the machine. To set it up, run:
-//
-// ./runsc/test/install.sh [--runtime <name>]
-//
-// The tests expect the runtime name to be provided in the RUNSC_RUNTIME
-// environment variable (default: runsc-test).
-//
-// Each test calls docker commands to start up a container, and tests that it is
-// behaving properly, with various runsc commands. The container is killed and deleted
-// at the end.
-
-package integration
-
-import (
- "fmt"
- "strconv"
- "strings"
- "syscall"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/runsc/test/testutil"
-)
-
-func TestExecCapabilities(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := testutil.MakeDocker("exec-test")
-
- // Start the container.
- if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- defer d.CleanUp()
-
- matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
- if err != nil {
- t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
- }
- if len(matches) != 2 {
- t.Fatalf("There should be a match for the whole line and the capability bitmask")
- }
- capString := matches[1]
- t.Log("Root capabilities:", capString)
-
- // CAP_NET_RAW was in the capability set for the container, but was
- // removed. However, `exec` does not remove it. Verify that it's not
- // set in the container, then re-add it for comparison.
- caps, err := strconv.ParseUint(capString, 16, 64)
- if err != nil {
- t.Fatalf("failed to convert capabilities %q: %v", capString, err)
- }
- if caps&(1<<uint64(linux.CAP_NET_RAW)) != 0 {
- t.Fatalf("CAP_NET_RAW should be filtered, but is set in the container: %x", caps)
- }
- caps |= 1 << uint64(linux.CAP_NET_RAW)
- want := fmt.Sprintf("CapEff:\t%016x\n", caps)
-
- // Now check that exec'd process capabilities match the root.
- got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
- if err != nil {
- t.Fatalf("docker exec failed: %v", err)
- }
- if got != want {
- t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
- }
-}
-
-func TestExecJobControl(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := testutil.MakeDocker("exec-job-control-test")
-
- // Start the container.
- if err := d.Run("alpine", "sleep", "1000"); err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- defer d.CleanUp()
-
- // Exec 'sh' with an attached pty.
- cmd, ptmx, err := d.ExecWithTerminal("sh")
- if err != nil {
- t.Fatalf("docker exec failed: %v", err)
- }
- defer ptmx.Close()
-
- // Call "sleep 100 | cat" in the shell. We pipe to cat so that there
- // will be two processes in the foreground process group.
- if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
- }
-
- // Give shell a few seconds to start executing the sleep.
- time.Sleep(2 * time.Second)
-
- // Send a ^C to the pty, which should kill sleep and cat, but not the
- // shell. \x03 is ASCII "end of text", which is the same as ^C.
- if _, err := ptmx.Write([]byte{'\x03'}); err != nil {
- t.Fatalf("error writing to pty: %v", err)
- }
-
- // The shell should still be alive at this point. Sleep should have
- // exited with code 2+128=130. We'll exit with 10 plus that number, so
- // that we can be sure that the shell did not get signalled.
- if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil {
- t.Fatalf("error writing to pty: %v", err)
- }
-
- // Exec process should exit with code 10+130=140.
- ps, err := cmd.Process.Wait()
- if err != nil {
- t.Fatalf("error waiting for exec process: %v", err)
- }
- ws := ps.Sys().(syscall.WaitStatus)
- if !ws.Exited() {
- t.Errorf("ws.Exited got false, want true")
- }
- if got, want := ws.ExitStatus(), 140; got != want {
- t.Errorf("ws.ExitedStatus got %d, want %d", got, want)
- }
-}
-
-// Test that failure to exec returns proper error message.
-func TestExecError(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
- d := testutil.MakeDocker("exec-error-test")
-
- // Start the container.
- if err := d.Run("alpine", "sleep", "1000"); err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- defer d.CleanUp()
-
- _, err := d.Exec("no_can_find")
- if err == nil {
- t.Fatalf("docker exec didn't fail")
- }
- if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) {
- t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want)
- }
-}
diff --git a/runsc/test/testutil/BUILD b/runsc/testutil/BUILD
index 327e7ca4d..d44ebc906 100644
--- a/runsc/test/testutil/BUILD
+++ b/runsc/testutil/BUILD
@@ -4,19 +4,14 @@ package(licenses = ["notice"])
go_library(
name = "testutil",
- srcs = [
- "crictl.go",
- "docker.go",
- "testutil.go",
- "testutil_race.go",
- ],
- importpath = "gvisor.dev/gvisor/runsc/test/testutil",
+ testonly = 1,
+ srcs = ["testutil.go"],
+ importpath = "gvisor.dev/gvisor/runsc/testutil",
visibility = ["//:sandbox"],
deps = [
"//runsc/boot",
"//runsc/specutils",
"@com_github_cenkalti_backoff//:go_default_library",
- "@com_github_kr_pty//:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
],
)
diff --git a/runsc/test/testutil/testutil.go b/runsc/testutil/testutil.go
index 4a3dfa0e3..edf8b126c 100644
--- a/runsc/test/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -18,18 +18,22 @@ package testutil
import (
"bufio"
"context"
+ "debug/elf"
"encoding/base32"
"encoding/json"
+ "flag"
"fmt"
"io"
"io/ioutil"
"log"
+ "math"
"math/rand"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
+ "strconv"
"strings"
"sync"
"sync/atomic"
@@ -42,12 +46,18 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
)
+var (
+ checkpoint = flag.Bool("checkpoint", true, "control checkpoint/restore support")
+)
+
func init() {
rand.Seed(time.Now().UnixNano())
}
-// RaceEnabled is set to true if it was built with '--race' option.
-var RaceEnabled = false
+// IsCheckpointSupported returns the relevant command line flag.
+func IsCheckpointSupported() bool {
+ return *checkpoint
+}
// TmpDir returns the absolute path to a writable directory that can be used as
// scratch by the test.
@@ -191,14 +201,11 @@ func SetupRootDir() (string, error) {
// SetupContainer creates a bundle and root dir for the container, generates a
// test config, and writes the spec to config.json in the bundle dir.
func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, err error) {
- // Setup root dir if one hasn't been provided.
- if len(conf.RootDir) == 0 {
- rootDir, err = SetupRootDir()
- if err != nil {
- return "", "", err
- }
- conf.RootDir = rootDir
+ rootDir, err = SetupRootDir()
+ if err != nil {
+ return "", "", err
}
+ conf.RootDir = rootDir
bundleDir, err = SetupBundleDir(spec)
return rootDir, bundleDir, err
}
@@ -419,3 +426,58 @@ func WriteTmpFile(pattern, text string) (string, error) {
func RandomName(prefix string) string {
return fmt.Sprintf("%s-%06d", prefix, rand.Int31n(1000000))
}
+
+// IsStatic returns true iff the given file is a static binary.
+func IsStatic(filename string) (bool, error) {
+ f, err := elf.Open(filename)
+ if err != nil {
+ return false, err
+ }
+ for _, prog := range f.Progs {
+ if prog.Type == elf.PT_INTERP {
+ return false, nil // Has interpreter.
+ }
+ }
+ return true, nil
+}
+
+// TestBoundsForShard calculates the beginning and end indices for the test
+// based on the TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. The
+// returned ints are the beginning (inclusive) and end (exclusive) of the
+// subslice corresponding to the shard. If either of the env vars are not
+// present, then the function will return bounds that include all tests. If
+// there are more shards than there are tests, then the returned list may be
+// empty.
+func TestBoundsForShard(numTests int) (int, int, error) {
+ var (
+ begin = 0
+ end = numTests
+ )
+ indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
+ if indexStr == "" || totalStr == "" {
+ return begin, end, nil
+ }
+
+ // Parse index and total to ints.
+ shardIndex, err := strconv.Atoi(indexStr)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
+ }
+ shardTotal, err := strconv.Atoi(totalStr)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ }
+
+ // Calculate!
+ shardSize := int(math.Ceil(float64(numTests) / float64(shardTotal)))
+ begin = shardIndex * shardSize
+ end = ((shardIndex + 1) * shardSize)
+ if begin > numTests {
+ // Nothing to run.
+ return 0, 0, nil
+ }
+ if end > numTests {
+ end = numTests
+ }
+ return begin, end, nil
+}
diff --git a/runsc/tools/dockercfg/BUILD b/runsc/tools/dockercfg/BUILD
deleted file mode 100644
index 5cff917ed..000000000
--- a/runsc/tools/dockercfg/BUILD
+++ /dev/null
@@ -1,10 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "dockercfg",
- srcs = ["dockercfg.go"],
- visibility = ["//visibility:public"],
- deps = ["@com_github_google_subcommands//:go_default_library"],
-)
diff --git a/runsc/tools/dockercfg/dockercfg.go b/runsc/tools/dockercfg/dockercfg.go
deleted file mode 100644
index eb9dbd421..000000000
--- a/runsc/tools/dockercfg/dockercfg.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Helper tool to configure Docker daemon.
-package main
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io/ioutil"
- "log"
- "os"
-
- "flag"
- "github.com/google/subcommands"
-)
-
-var (
- configFile = flag.String("config_file", "/etc/docker/daemon.json", "path to Docker daemon config file")
- experimental = flag.Bool("experimental", false, "enable experimental features")
-)
-
-func main() {
- subcommands.Register(subcommands.HelpCommand(), "")
- subcommands.Register(subcommands.FlagsCommand(), "")
- subcommands.Register(&runtimeAdd{}, "")
- subcommands.Register(&runtimeRemove{}, "")
-
- // All subcommands must be registered before flag parsing.
- flag.Parse()
-
- exitCode := subcommands.Execute(context.Background())
- os.Exit(int(exitCode))
-}
-
-type runtime struct {
- Path string `json:"path,omitempty"`
- RuntimeArgs []string `json:"runtimeArgs,omitempty"`
-}
-
-// runtimeAdd implements subcommands.Command.
-type runtimeAdd struct {
-}
-
-// Name implements subcommands.Command.Name.
-func (*runtimeAdd) Name() string {
- return "runtime-add"
-}
-
-// Synopsis implements subcommands.Command.Synopsis.
-func (*runtimeAdd) Synopsis() string {
- return "adds a runtime to docker daemon configuration"
-}
-
-// Usage implements subcommands.Command.Usage.
-func (*runtimeAdd) Usage() string {
- return `runtime-add [flags] <name> <path> [args...] -- if provided, args are passed as arguments to the runtime
-`
-}
-
-// SetFlags implements subcommands.Command.SetFlags.
-func (*runtimeAdd) SetFlags(*flag.FlagSet) {
-}
-
-// Execute implements subcommands.Command.Execute.
-func (r *runtimeAdd) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- if f.NArg() < 2 {
- f.Usage()
- return subcommands.ExitUsageError
- }
- name := f.Arg(0)
- path := f.Arg(1)
- runtimeArgs := f.Args()[2:]
-
- fmt.Printf("Adding runtime %q to file %q\n", name, *configFile)
- c, err := readConfig(*configFile)
- if err != nil {
- log.Fatalf("Error reading config file %q: %v", *configFile, err)
- }
-
- var rts map[string]interface{}
- if i, ok := c["runtimes"]; ok {
- rts = i.(map[string]interface{})
- } else {
- rts = make(map[string]interface{})
- c["runtimes"] = rts
- }
- if *experimental {
- c["experimental"] = true
- }
- rts[name] = runtime{Path: path, RuntimeArgs: runtimeArgs}
-
- if err := writeConfig(c, *configFile); err != nil {
- log.Fatalf("Error writing config file %q: %v", *configFile, err)
- }
- return subcommands.ExitSuccess
-}
-
-// runtimeRemove implements subcommands.Command.
-type runtimeRemove struct {
-}
-
-// Name implements subcommands.Command.Name.
-func (*runtimeRemove) Name() string {
- return "runtime-rm"
-}
-
-// Synopsis implements subcommands.Command.Synopsis.
-func (*runtimeRemove) Synopsis() string {
- return "removes a runtime from docker daemon configuration"
-}
-
-// Usage implements subcommands.Command.Usage.
-func (*runtimeRemove) Usage() string {
- return `runtime-rm [flags] <name>
-`
-}
-
-// SetFlags implements subcommands.Command.SetFlags.
-func (*runtimeRemove) SetFlags(*flag.FlagSet) {
-}
-
-// Execute implements subcommands.Command.Execute.
-func (r *runtimeRemove) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- if f.NArg() != 1 {
- f.Usage()
- return subcommands.ExitUsageError
- }
- name := f.Arg(0)
-
- fmt.Printf("Removing runtime %q from file %q\n", name, *configFile)
- c, err := readConfig(*configFile)
- if err != nil {
- log.Fatalf("Error reading config file %q: %v", *configFile, err)
- }
-
- var rts map[string]interface{}
- if i, ok := c["runtimes"]; ok {
- rts = i.(map[string]interface{})
- } else {
- log.Fatalf("runtime %q not found", name)
- }
- if _, ok := rts[name]; !ok {
- log.Fatalf("runtime %q not found", name)
- }
- delete(rts, name)
-
- if err := writeConfig(c, *configFile); err != nil {
- log.Fatalf("Error writing config file %q: %v", *configFile, err)
- }
- return subcommands.ExitSuccess
-}
-
-func readConfig(path string) (map[string]interface{}, error) {
- configBytes, err := ioutil.ReadFile(path)
- if err != nil && !os.IsNotExist(err) {
- return nil, err
- }
- c := make(map[string]interface{})
- if len(configBytes) > 0 {
- if err := json.Unmarshal(configBytes, &c); err != nil {
- return nil, err
- }
- }
- return c, nil
-}
-
-func writeConfig(c map[string]interface{}, path string) error {
- b, err := json.MarshalIndent(c, "", " ")
- if err != nil {
- return err
- }
-
- if err := os.Rename(path, path+"~"); err != nil && !os.IsNotExist(err) {
- return fmt.Errorf("error renaming config file %q: %v", path, err)
- }
- if err := ioutil.WriteFile(path, b, 0644); err != nil {
- return fmt.Errorf("error writing config file %q: %v", path, err)
- }
- return nil
-}
diff --git a/runsc/version.go b/runsc/version.go
index ce0573a9b..ab9194b9d 100644
--- a/runsc/version.go
+++ b/runsc/version.go
@@ -15,4 +15,4 @@
package main
// version is set during linking.
-var version = ""
+var version = "VERSION_MISSING"
diff --git a/runsc/version_test.sh b/runsc/version_test.sh
new file mode 100755
index 000000000..cc0ca3f05
--- /dev/null
+++ b/runsc/version_test.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euf -x -o pipefail
+
+readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc"
+readonly version=$($runsc --version)
+
+# Version should should not match VERSION, which is the default and which will
+# also appear if something is wrong with workspace_status.sh script.
+if [[ $version =~ "VERSION" ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+# Version should contain at least one number.
+if [[ ! $version =~ [0-9] ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+echo "PASS: Got OK version $version"
+exit 0
diff --git a/scripts/build.sh b/scripts/build.sh
new file mode 100755
index 000000000..0b3d1b316
--- /dev/null
+++ b/scripts/build.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Install required packages for make_repository.sh et al.
+sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils
+
+# Build runsc.
+runsc=$(build -c opt //runsc)
+
+# Build packages.
+pkg=$(build -c opt //runsc:runsc-debian)
+
+# Build a repository, if the key is available.
+if [[ -v KOKORO_REPO_KEY ]]; then
+ repo=$(tools/make_repository.sh "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" gvisor-bot@google.com main ${pkg})
+fi
+
+# Install installs artifacts.
+install() {
+ local -r binaries_dir="$1"
+ local -r repo_dir="$2"
+ mkdir -p "${binaries_dir}"
+ cp -f "${runsc}" "${binaries_dir}"/runsc
+ sha512sum "${binaries_dir}"/runsc | awk '{print $1 " runsc"}' > "${binaries_dir}"/runsc.sha512
+ if [[ -v repo ]]; then
+ rm -rf "${repo_dir}" && mkdir -p "$(dirname "${repo_dir}")"
+ cp -a "${repo}" "${repo_dir}"
+ fi
+}
+
+# Move the runsc binary into "latest" directory, and also a directory with the
+# current date. If the current commit happens to correpond to a tag, then we
+# will also move everything into a directory named after the given tag.
+if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
+ if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then
+ # The "latest" directory and current date.
+ stamp="$(date -Idate)"
+ install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/nightly/latest"
+ install "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/nightly/${stamp}"
+ else
+ # Is it a tagged release? Build that instead. In that case, we also try to
+ # update the base release directory, in case this is an update. Finally, we
+ # update the "release" directory, which has the last released version.
+ tags="$(git tag --points-at HEAD)"
+ if ! [[ -z "${tags}" ]]; then
+ # Note that a given commit can match any number of tags. We have to
+ # iterate through all possible tags and produce associated artifacts.
+ for tag in ${tags}; do
+ name=$(echo "${tag}" | cut -d'-' -f2)
+ base=$(echo "${name}" | cut -d'.' -f1)
+ install "${KOKORO_ARTIFACTS_DIR}/release/${name}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/${name}"
+ if [[ "${base}" != "${tag}" ]]; then
+ install "${KOKORO_ARTIFACTS_DIR}/release/${base}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/${base}"
+ fi
+ install "${KOKORO_ARTIFACTS_DIR}/release/latest" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/latest"
+ done
+ fi
+ fi
+fi
diff --git a/scripts/common.sh b/scripts/common.sh
new file mode 100755
index 000000000..6dabad141
--- /dev/null
+++ b/scripts/common.sh
@@ -0,0 +1,80 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -xeou pipefail
+
+if [[ -f $(dirname $0)/common_google.sh ]]; then
+ source $(dirname $0)/common_google.sh
+else
+ source $(dirname $0)/common_bazel.sh
+fi
+
+# Ensure it attempts to collect logs in all cases.
+trap collect_logs EXIT
+
+function set_runtime() {
+ RUNTIME=${1:-runsc}
+ RUNSC_BIN=/tmp/"${RUNTIME}"/runsc
+ RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs
+ RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+}
+
+function test_runsc() {
+ test --test_arg=--runtime=${RUNTIME} "$@"
+}
+
+function install_runsc_for_test() {
+ local -r test_name=$1
+ shift
+ if [[ -z "${test_name}" ]]; then
+ echo "Missing mandatory test name"
+ exit 1
+ fi
+
+ # Add test to the name, so it doesn't conflict with other runtimes.
+ set_runtime $(find_branch_name)_"${test_name}"
+
+ # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name
+ # down to the runtime.
+ install_runsc "${RUNTIME}" \
+ --TESTONLY-test-name-env=RUNSC_TEST_NAME \
+ --debug \
+ --strace \
+ --log-packets \
+ "$@"
+}
+
+# Installs the runsc with given runtime name. set_runtime must have been called
+# to set runtime and logs location.
+function install_runsc() {
+ local -r runtime=$1
+ shift
+
+ # Prepare the runtime binary.
+ local -r output=$(build //runsc)
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f "${output}" "${RUNSC_BIN}"
+ chmod 0755 "${RUNSC_BIN}"
+
+ # Install the runtime.
+ sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
+
+ # Clear old logs files that may exist.
+ sudo rm -f "${RUNSC_LOGS_DIR}"/*
+
+ # Restart docker to pick up the new runtime configuration.
+ sudo systemctl restart docker
+}
diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh
new file mode 100755
index 000000000..f8ec967b1
--- /dev/null
+++ b/scripts/common_bazel.sh
@@ -0,0 +1,99 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Install the latest version of Bazel and log the version.
+(which use_bazel.sh && use_bazel.sh latest) || which bazel
+bazel version
+
+# Switch into the workspace; only necessary if run with kokoro.
+if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
+ cd git/repo
+elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
+ cd github/repo
+fi
+
+# Set the standard bazel flags.
+declare -r BAZEL_FLAGS=(
+ "--show_timestamps"
+ "--test_output=errors"
+ "--keep_going"
+ "--verbose_failures=true"
+)
+if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]] || [[ -v RBE_PROJECT_ID ]]; then
+ declare -r RBE_PROJECT_ID="${RBE_PROJECT_ID:-gvisor-rbe}"
+ declare -r BAZEL_RBE_FLAGS=(
+ "--config=remote"
+ "--project_id=${RBE_PROJECT_ID}"
+ "--remote_instance_name=projects/${RBE_PROJECT_ID}/instances/default_instance"
+ )
+fi
+if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
+ declare -r BAZEL_RBE_AUTH_FLAGS=(
+ "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
+ )
+fi
+
+# Wrap bazel.
+function build() {
+ bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 |
+ tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }'
+}
+
+function test() {
+ bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@"
+}
+
+function run() {
+ local binary=$1
+ shift
+ bazel run "${binary}" -- "$@"
+}
+
+function run_as_root() {
+ local binary=$1
+ shift
+ bazel run --run_under="sudo" "${binary}" -- "$@"
+}
+
+function collect_logs() {
+ # Zip out everything into a convenient form.
+ if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then
+ # Move test logs to Kokoro directory. tar is used to conveniently perform
+ # renames while moving files.
+ find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
+ tar --create --files-from - --transform 's/test\./sponge_log./' |
+ tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
+
+ # Collect sentry logs, if any.
+ if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then
+ # Check if the directory is empty or not (only the first line it needed).
+ local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1)
+ if [[ "${logs}" ]]; then
+ local -r archive=runsc_logs_"${RUNTIME}".tar.gz
+ if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then
+ echo "runsc logs will be uploaded to:"
+ echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
+ echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
+ fi
+ tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" .
+ fi
+ fi
+ fi
+}
+
+function find_branch_name() {
+ git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename
+}
diff --git a/scripts/dev.sh b/scripts/dev.sh
new file mode 100755
index 000000000..c67003018
--- /dev/null
+++ b/scripts/dev.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# common.sh sets '-x', but it's annoying to see so much output.
+set +x
+
+# Defaults
+declare -i REFRESH=0
+declare NAME=$(find_branch_name)
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --refresh)
+ REFRESH=1
+ ;;
+ --help)
+ echo "Use this script to build and install runsc with Docker."
+ echo
+ echo "usage: $0 [--refresh] [runtime_name]"
+ exit 1
+ ;;
+ *)
+ NAME=$1
+ ;;
+ esac
+ shift
+done
+
+set_runtime "${NAME}"
+echo
+echo "Using runtime=${RUNTIME}"
+echo
+
+echo Building runsc...
+# Build first and fail on error. $() prevents "set -e" from reporting errors.
+build //runsc
+declare OUTPUT="$(build //runsc)"
+
+if [[ ${REFRESH} -eq 0 ]]; then
+ install_runsc "${RUNTIME}" --net-raw
+ install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets
+
+ echo
+ echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup."
+ echo "Use --runtime="${RUNTIME}" with your Docker command."
+ echo " docker run --rm --runtime="${RUNTIME}" hello-world"
+ echo
+ echo "If you rebuild, use $0 --refresh."
+
+else
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f ${OUTPUT} "${RUNSC_BIN}"
+
+ echo
+ echo "Runtime ${RUNTIME} refreshed."
+fi
+
+echo "Logs are in: ${RUNSC_LOGS_DIR}"
diff --git a/scripts/do_tests.sh b/scripts/do_tests.sh
new file mode 100755
index 000000000..a3a387c37
--- /dev/null
+++ b/scripts/do_tests.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Build runsc.
+build //runsc
+
+# run runsc do without root privileges.
+run //runsc --rootless do true
+run //runsc --rootless --network=none do true
+
+# run runsc do with root privileges.
+run_as_root //runsc do true
diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh
new file mode 100755
index 000000000..72ba05260
--- /dev/null
+++ b/scripts/docker_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+install_runsc_for_test docker
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/go.sh b/scripts/go.sh
new file mode 100755
index 000000000..0dbfb7747
--- /dev/null
+++ b/scripts/go.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Build the go path.
+build :gopath
+
+# Build the synthetic branch.
+tools/go_branch.sh
+
+# Checkout the new branch.
+git checkout go && git clean -f
+
+# Build everything.
+go build ./...
+
+# Push, if required.
+if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then
+ if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ git config --global credential.helper cache
+ git credential approve <<EOF
+protocol=https
+host=github.com
+username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}")
+password=x-oauth-basic
+EOF
+ fi
+ git push origin go:go
+fi
diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh
new file mode 100755
index 000000000..41298293d
--- /dev/null
+++ b/scripts/hostnet_tests.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test hostnet --network=host
+test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh
new file mode 100755
index 000000000..5662401df
--- /dev/null
+++ b/scripts/kvm_tests.sh
@@ -0,0 +1,28 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Ensure that KVM is loaded, and we can use it.
+(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
+sudo chmod a+rw /dev/kvm
+
+# Run all KVM platform tests (locally).
+run_as_root //pkg/sentry/platform/kvm:kvm_test
+
+# Install the KVM runtime and run all integration tests.
+install_runsc_for_test kvm --platform=kvm
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/make_tests.sh b/scripts/make_tests.sh
new file mode 100755
index 000000000..79426756d
--- /dev/null
+++ b/scripts/make_tests.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+top_level=$(git rev-parse --show-toplevel 2>/dev/null)
+[[ $? -eq 0 ]] && cd "${top_level}" || exit 1
+
+make
+make runsc
+make BAZEL_OPTIONS="build //..." bazel
+make bazel-shutdown
diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh
new file mode 100755
index 000000000..2a1f12c0b
--- /dev/null
+++ b/scripts/overlay_tests.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Install the runtime and perform basic tests.
+install_runsc_for_test overlay --overlay
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/release.sh b/scripts/release.sh
new file mode 100755
index 000000000..b936bcc77
--- /dev/null
+++ b/scripts/release.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Tag a release only if provided.
+if ! [[ -v KOKORO_RELEASE_COMMIT ]]; then
+ echo "No KOKORO_RELEASE_COMMIT provided." >&2
+ exit 1
+fi
+if ! [[ -v KOKORO_RELEASE_TAG ]]; then
+ echo "No KOKORO_RELEASE_TAG provided." >&2
+ exit 1
+fi
+
+# Unless an explicit releaser is provided, use the bot e-mail.
+declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot}
+declare -r EMAIL=${EMAIL:-${KOKORO_RELEASE_AUTHOR}@google.com}
+
+# Ensure we have an appropriate configuration for the tag.
+git config --get user.name || git config user.name "gVisor-bot"
+git config --get user.email || git config user.email "${EMAIL}"
+
+# Run the release tool, which pushes to the origin repository.
+tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}"
diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh
new file mode 100755
index 000000000..4e4fcc76b
--- /dev/null
+++ b/scripts/root_tests.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Reinstall the latest containerd shim.
+declare -r base="https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim"
+declare -r latest=$(mktemp --tmpdir gvisor-containerd-shim-latest.XXXXXX)
+declare -r shim_path=$(mktemp --tmpdir gvisor-containerd-shim.XXXXXX)
+wget --no-verbose "${base}"/latest -O ${latest}
+wget --no-verbose "${base}"/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
+chmod +x ${shim_path}
+sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim
+
+# Run the tests that require root.
+install_runsc_for_test root
+run_as_root //test/root:root_test --runtime=${RUNTIME}
+
diff --git a/scripts/simple_tests.sh b/scripts/simple_tests.sh
new file mode 100755
index 000000000..585216aae
--- /dev/null
+++ b/scripts/simple_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Run all simple tests (locally).
+test //pkg/... //runsc/... //tools/...
diff --git a/scripts/syscall_tests.sh b/scripts/syscall_tests.sh
new file mode 100755
index 000000000..a131b2d50
--- /dev/null
+++ b/scripts/syscall_tests.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# Run all ptrace-variants of the system call tests.
+test --test_tag_filters=runsc_ptrace //test/syscalls/...
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 000000000..09c36b461
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,18 @@
+# Tests
+
+The tests defined under this path are verifying functionality beyond what unit
+tests can cover, e.g. integration and end to end tests. Due to their nature,
+they may need extra setup in the test machine and extra configuration to run.
+
+- **syscalls**: system call tests use a local runner, and do not require
+ additional configuration in the machine.
+- **integration:** defines integration tests that uses `docker run` to test
+ functionality.
+- **image:** basic end to end test for popular images. These require the same
+ setup as integration tests.
+- **root:** tests that require to be run as root.
+- **util:** utilities library to support the tests.
+
+For the above noted cases, the relevant runtime must be installed via `runsc
+install` before running. This is handled automatically by the test scripts in
+the `kokoro` directory.
diff --git a/runsc/test/integration/BUILD b/test/e2e/BUILD
index 12065617c..4fe03a220 100644
--- a/runsc/test/integration/BUILD
+++ b/test/e2e/BUILD
@@ -1,9 +1,8 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-load("//runsc/test:build_defs.bzl", "runtime_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(licenses = ["notice"])
-runtime_test(
+go_test(
name = "integration_test",
size = "large",
srcs = [
@@ -17,14 +16,18 @@ runtime_test(
"manual",
"local",
],
+ visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
- "//runsc/test/testutil",
+ "//pkg/bits",
+ "//runsc/dockerutil",
+ "//runsc/specutils",
+ "//runsc/testutil",
],
)
go_library(
name = "integration",
srcs = ["integration.go"],
- importpath = "gvisor.dev/gvisor/runsc/test/integration",
+ importpath = "gvisor.dev/gvisor/test/integration",
)
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
new file mode 100644
index 000000000..88d26e865
--- /dev/null
+++ b/test/e2e/exec_test.go
@@ -0,0 +1,256 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package integration provides end-to-end integration tests for runsc. These
+// tests require docker and runsc to be installed on the machine.
+//
+// Each test calls docker commands to start up a container, and tests that it
+// is behaving properly, with various runsc commands. The container is killed
+// and deleted at the end.
+
+package integration
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// Test that exec uses the exact same capability set as the container.
+func TestExecCapabilities(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-capabilities-test")
+
+ // Start the container.
+ if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ if err != nil {
+ t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
+ }
+ if len(matches) != 2 {
+ t.Fatalf("There should be a match for the whole line and the capability bitmask")
+ }
+ want := fmt.Sprintf("CapEff:\t%s\n", matches[1])
+ t.Log("Root capabilities:", want)
+
+ // Now check that exec'd process capabilities match the root.
+ got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ t.Logf("CapEff: %v", got)
+ if got != want {
+ t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
+ }
+}
+
+// Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW
+// which is removed from the container when --net-raw=false.
+func TestExecPrivileged(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-privileged-test")
+
+ // Start the container with all capabilities dropped.
+ if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Check that all capabilities where dropped from container.
+ matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ if err != nil {
+ t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
+ }
+ if len(matches) != 2 {
+ t.Fatalf("There should be a match for the whole line and the capability bitmask")
+ }
+ containerCaps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ t.Fatalf("failed to convert capabilities %q: %v", matches[1], err)
+ }
+ t.Logf("Container capabilities: %#x", containerCaps)
+ if containerCaps != 0 {
+ t.Fatalf("Container should have no capabilities: %x", containerCaps)
+ }
+
+ // Check that 'exec --privileged' adds all capabilities, except
+ // for CAP_NET_RAW.
+ got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ t.Logf("Exec CapEff: %v", got)
+ want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW)))
+ if got != want {
+ t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
+ }
+}
+
+func TestExecJobControl(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-job-control-test")
+
+ // Start the container.
+ if err := d.Run("alpine", "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec 'sh' with an attached pty.
+ cmd, ptmx, err := d.ExecWithTerminal("sh")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ defer ptmx.Close()
+
+ // Call "sleep 100 | cat" in the shell. We pipe to cat so that there
+ // will be two processes in the foreground process group.
+ if _, err := ptmx.Write([]byte("sleep 100 | cat\n")); err != nil {
+ t.Fatalf("error writing to pty: %v", err)
+ }
+
+ // Give shell a few seconds to start executing the sleep.
+ time.Sleep(2 * time.Second)
+
+ // Send a ^C to the pty, which should kill sleep and cat, but not the
+ // shell. \x03 is ASCII "end of text", which is the same as ^C.
+ if _, err := ptmx.Write([]byte{'\x03'}); err != nil {
+ t.Fatalf("error writing to pty: %v", err)
+ }
+
+ // The shell should still be alive at this point. Sleep should have
+ // exited with code 2+128=130. We'll exit with 10 plus that number, so
+ // that we can be sure that the shell did not get signalled.
+ if _, err := ptmx.Write([]byte("exit $(expr $? + 10)\n")); err != nil {
+ t.Fatalf("error writing to pty: %v", err)
+ }
+
+ // Exec process should exit with code 10+130=140.
+ ps, err := cmd.Process.Wait()
+ if err != nil {
+ t.Fatalf("error waiting for exec process: %v", err)
+ }
+ ws := ps.Sys().(syscall.WaitStatus)
+ if !ws.Exited() {
+ t.Errorf("ws.Exited got false, want true")
+ }
+ if got, want := ws.ExitStatus(), 140; got != want {
+ t.Errorf("ws.ExitedStatus got %d, want %d", got, want)
+ }
+}
+
+// Test that failure to exec returns proper error message.
+func TestExecError(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-error-test")
+
+ // Start the container.
+ if err := d.Run("alpine", "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ _, err := d.Exec("no_can_find")
+ if err == nil {
+ t.Fatalf("docker exec didn't fail")
+ }
+ if want := `error finding executable "no_can_find" in PATH`; !strings.Contains(err.Error(), want) {
+ t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want)
+ }
+}
+
+// Test that exec inherits environment from run.
+func TestExecEnv(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-env-test")
+
+ // Start the container with env FOO=BAR.
+ if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec "echo $FOO".
+ got, err := d.Exec("/bin/sh", "-c", "echo $FOO")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "BAR"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
+
+// Test that exec always has HOME environment set, even when not set in run.
+func TestExecEnvHasHome(t *testing.T) {
+ // Base alpine image does not have any environment variables set.
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-env-home-test")
+
+ // We will check that HOME is set for root user, and also for a new
+ // non-root user we will create.
+ newUID := 1234
+ newHome := "/foo/bar"
+
+ // Create a new user with a home directory, and then sleep.
+ script := fmt.Sprintf(`
+ mkdir -p -m 777 %s && \
+ adduser foo -D -u %d -h %s && \
+ sleep 1000`, newHome, newUID, newHome)
+ if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec "echo $HOME", and expect to see "/root".
+ got, err := d.Exec("/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "/root"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+
+ // Execute the same as uid 123 and expect newHome.
+ got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := newHome; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
diff --git a/runsc/test/integration/integration.go b/test/e2e/integration.go
index 4cd5f6c24..4cd5f6c24 100644
--- a/runsc/test/integration/integration.go
+++ b/test/e2e/integration.go
diff --git a/runsc/test/integration/integration_test.go b/test/e2e/integration_test.go
index 7cef4b9dd..7cc0de129 100644
--- a/runsc/test/integration/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -18,10 +18,11 @@
// behaving properly, with various runsc commands. The container is killed and
// deleted at the end.
//
-// Setup instruction in runsc/test/README.md.
+// Setup instruction in test/README.md.
package integration
import (
+ "flag"
"fmt"
"net"
"net/http"
@@ -32,7 +33,8 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
// httpRequestSucceeds sends a request to a given url and checks that the status is OK.
@@ -51,10 +53,10 @@ func httpRequestSucceeds(client http.Client, server string, port int) error {
// TestLifeCycle tests a basic Create/Start/Stop docker container life cycle.
func TestLifeCycle(t *testing.T) {
- if err := testutil.Pull("nginx"); err != nil {
+ if err := dockerutil.Pull("nginx"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("lifecycle-test")
+ d := dockerutil.MakeDocker("lifecycle-test")
if err := d.Create("-p", "80", "nginx"); err != nil {
t.Fatal("docker create failed:", err)
}
@@ -87,15 +89,15 @@ func TestLifeCycle(t *testing.T) {
func TestPauseResume(t *testing.T) {
const img = "gcr.io/gvisor-presubmit/python-hello"
- if !testutil.IsPauseResumeSupported() {
- t.Log("Pause/resume is not supported, skipping test.")
+ if !testutil.IsCheckpointSupported() {
+ t.Log("Checkpoint is not supported, skipping test.")
return
}
- if err := testutil.Pull(img); err != nil {
+ if err := dockerutil.Pull(img); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("pause-resume-test")
+ d := dockerutil.MakeDocker("pause-resume-test")
if err := d.Run("-p", "8080", img); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -151,14 +153,15 @@ func TestPauseResume(t *testing.T) {
func TestCheckpointRestore(t *testing.T) {
const img = "gcr.io/gvisor-presubmit/python-hello"
- if !testutil.IsPauseResumeSupported() {
+ if !testutil.IsCheckpointSupported() {
t.Log("Pause/resume is not supported, skipping test.")
return
}
- if err := testutil.Pull(img); err != nil {
+
+ if err := dockerutil.Pull(img); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("save-restore-test")
+ d := dockerutil.MakeDocker("save-restore-test")
if err := d.Run("-p", "8080", img); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -196,7 +199,7 @@ func TestCheckpointRestore(t *testing.T) {
// Create client and server that talk to each other using the local IP.
func TestConnectToSelf(t *testing.T) {
- d := testutil.MakeDocker("connect-to-self-test")
+ d := dockerutil.MakeDocker("connect-to-self-test")
// Creates server that replies "server" and exists. Sleeps at the end because
// 'docker exec' gets killed if the init process exists before it can finish.
@@ -228,10 +231,10 @@ func TestConnectToSelf(t *testing.T) {
}
func TestMemLimit(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
cmd := "cat /proc/meminfo | grep MemTotal: | awk '{print $2}'"
out, err := d.RunFg("--memory=500MB", "alpine", "sh", "-c", cmd)
if err != nil {
@@ -258,10 +261,10 @@ func TestMemLimit(t *testing.T) {
}
func TestNumCPU(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
cmd := "cat /proc/cpuinfo | grep 'processor.*:' | wc -l"
out, err := d.RunFg("--cpuset-cpus=0", "alpine", "sh", "-c", cmd)
if err != nil {
@@ -280,10 +283,10 @@ func TestNumCPU(t *testing.T) {
// TestJobControl tests that job control characters are handled properly.
func TestJobControl(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("job-control-test")
+ d := dockerutil.MakeDocker("job-control-test")
// Start the container with an attached PTY.
_, ptmx, err := d.RunWithPty("alpine", "sh")
@@ -328,10 +331,10 @@ func TestJobControl(t *testing.T) {
// TestTmpFile checks that files inside '/tmp' are not overridden. In addition,
// it checks that working dir is created if it doesn't exit.
func TestTmpFile(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("tmp-file-test")
+ d := dockerutil.MakeDocker("tmp-file-test")
if err := d.Run("-w=/tmp/foo/bar", "--read-only", "alpine", "touch", "/tmp/foo/bar/file"); err != nil {
t.Fatal("docker run failed:", err)
}
@@ -339,6 +342,7 @@ func TestTmpFile(t *testing.T) {
}
func TestMain(m *testing.M) {
- testutil.EnsureSupportedDockerVersion()
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
os.Exit(m.Run())
}
diff --git a/runsc/test/integration/regression_test.go b/test/e2e/regression_test.go
index fb68dda99..2488be383 100644
--- a/runsc/test/integration/regression_test.go
+++ b/test/e2e/regression_test.go
@@ -18,7 +18,7 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
)
// Test that UDS can be created using overlay when parent directory is in lower
@@ -27,10 +27,10 @@ import (
// Prerequisite: the directory where the socket file is created must not have
// been open for write before bind(2) is called.
func TestBindOverlay(t *testing.T) {
- if err := testutil.Pull("ubuntu:trusty"); err != nil {
+ if err := dockerutil.Pull("ubuntu:trusty"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("bind-overlay-test")
+ d := dockerutil.MakeDocker("bind-overlay-test")
cmd := "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p"
got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd)
diff --git a/runsc/test/image/BUILD b/test/image/BUILD
index 58758fde5..09b0a0ad5 100644
--- a/runsc/test/image/BUILD
+++ b/test/image/BUILD
@@ -1,9 +1,8 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-load("//runsc/test:build_defs.bzl", "runtime_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
package(licenses = ["notice"])
-runtime_test(
+go_test(
name = "image_test",
size = "large",
srcs = [
@@ -21,11 +20,15 @@ runtime_test(
"manual",
"local",
],
- deps = ["//runsc/test/testutil"],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//runsc/dockerutil",
+ "//runsc/testutil",
+ ],
)
go_library(
name = "image",
srcs = ["image.go"],
- importpath = "gvisor.dev/gvisor/runsc/test/image",
+ importpath = "gvisor.dev/gvisor/test/image",
)
diff --git a/runsc/test/image/image.go b/test/image/image.go
index 297f1ab92..297f1ab92 100644
--- a/runsc/test/image/image.go
+++ b/test/image/image.go
diff --git a/runsc/test/image/image_test.go b/test/image/image_test.go
index ddaa2c13b..d0dcb1861 100644
--- a/runsc/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -14,14 +14,15 @@
// Package image provides end-to-end image tests for runsc.
-// Each test calls docker commands to start up a container, and tests that it is
-// behaving properly, like connecting to a port or looking at the output. The
-// container is killed and deleted at the end.
+// Each test calls docker commands to start up a container, and tests that it
+// is behaving properly, like connecting to a port or looking at the output.
+// The container is killed and deleted at the end.
//
-// Setup instruction in runsc/test/README.md.
+// Setup instruction in test/README.md.
package image
import (
+ "flag"
"fmt"
"io/ioutil"
"log"
@@ -32,11 +33,12 @@ import (
"testing"
"time"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func TestHelloWorld(t *testing.T) {
- d := testutil.MakeDocker("hello-test")
+ d := dockerutil.MakeDocker("hello-test")
if err := d.Run("hello-world"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -100,18 +102,18 @@ func testHTTPServer(t *testing.T, port int) {
}
func TestHttpd(t *testing.T) {
- if err := testutil.Pull("httpd"); err != nil {
+ if err := dockerutil.Pull("httpd"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("http-test")
+ d := dockerutil.MakeDocker("http-test")
- dir, err := testutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
// Start the container.
- mountArg := testutil.MountArg(dir, "/usr/local/apache2/htdocs", testutil.ReadOnly)
+ mountArg := dockerutil.MountArg(dir, "/usr/local/apache2/htdocs", dockerutil.ReadOnly)
if err := d.Run("-p", "80", mountArg, "httpd"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -132,18 +134,18 @@ func TestHttpd(t *testing.T) {
}
func TestNginx(t *testing.T) {
- if err := testutil.Pull("nginx"); err != nil {
+ if err := dockerutil.Pull("nginx"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("net-test")
+ d := dockerutil.MakeDocker("net-test")
- dir, err := testutil.PrepareFiles("latin10k.txt")
+ dir, err := dockerutil.PrepareFiles("latin10k.txt")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
// Start the container.
- mountArg := testutil.MountArg(dir, "/usr/share/nginx/html", testutil.ReadOnly)
+ mountArg := dockerutil.MountArg(dir, "/usr/share/nginx/html", dockerutil.ReadOnly)
if err := d.Run("-p", "80", mountArg, "nginx"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -164,10 +166,10 @@ func TestNginx(t *testing.T) {
}
func TestMysql(t *testing.T) {
- if err := testutil.Pull("mysql"); err != nil {
+ if err := dockerutil.Pull("mysql"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("mysql-test")
+ d := dockerutil.MakeDocker("mysql-test")
// Start the container.
if err := d.Run("-e", "MYSQL_ROOT_PASSWORD=foobar123", "mysql"); err != nil {
@@ -180,8 +182,8 @@ func TestMysql(t *testing.T) {
t.Fatalf("docker.WaitForOutput() timeout: %v", err)
}
- client := testutil.MakeDocker("mysql-client-test")
- dir, err := testutil.PrepareFiles("mysql.sql")
+ client := dockerutil.MakeDocker("mysql-client-test")
+ dir, err := dockerutil.PrepareFiles("mysql.sql")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -189,8 +191,8 @@ func TestMysql(t *testing.T) {
// Tell mysql client to connect to the server and execute the file in verbose
// mode to verify the output.
args := []string{
- testutil.LinkArg(&d, "mysql"),
- testutil.MountArg(dir, "/sql", testutil.ReadWrite),
+ dockerutil.LinkArg(&d, "mysql"),
+ dockerutil.MountArg(dir, "/sql", dockerutil.ReadWrite),
"mysql",
"mysql", "-hmysql", "-uroot", "-pfoobar123", "-v", "-e", "source /sql/mysql.sql",
}
@@ -212,10 +214,10 @@ func TestPythonHello(t *testing.T) {
// TODO(b/136503277): Once we have more complete python runtime tests,
// we can drop this one.
const img = "gcr.io/gvisor-presubmit/python-hello"
- if err := testutil.Pull(img); err != nil {
+ if err := dockerutil.Pull(img); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("python-hello-test")
+ d := dockerutil.MakeDocker("python-hello-test")
if err := d.Run("-p", "8080", img); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -244,10 +246,10 @@ func TestPythonHello(t *testing.T) {
}
func TestTomcat(t *testing.T) {
- if err := testutil.Pull("tomcat:8.0"); err != nil {
+ if err := dockerutil.Pull("tomcat:8.0"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("tomcat-test")
+ d := dockerutil.MakeDocker("tomcat-test")
if err := d.Run("-p", "8080", "tomcat:8.0"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -276,12 +278,12 @@ func TestTomcat(t *testing.T) {
}
func TestRuby(t *testing.T) {
- if err := testutil.Pull("ruby"); err != nil {
+ if err := dockerutil.Pull("ruby"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("ruby-test")
+ d := dockerutil.MakeDocker("ruby-test")
- dir, err := testutil.PrepareFiles("ruby.rb", "ruby.sh")
+ dir, err := dockerutil.PrepareFiles("ruby.rb", "ruby.sh")
if err != nil {
t.Fatalf("PrepareFiles() failed: %v", err)
}
@@ -289,7 +291,7 @@ func TestRuby(t *testing.T) {
t.Fatalf("os.Chmod(%q, 0333) failed: %v", dir, err)
}
- if err := d.Run("-p", "8080", testutil.MountArg(dir, "/src", testutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil {
+ if err := d.Run("-p", "8080", dockerutil.MountArg(dir, "/src", dockerutil.ReadOnly), "ruby", "/src/ruby.sh"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
defer d.CleanUp()
@@ -324,10 +326,10 @@ func TestRuby(t *testing.T) {
}
func TestStdio(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := testutil.MakeDocker("stdio-test")
+ d := dockerutil.MakeDocker("stdio-test")
wantStdout := "hello stdout"
wantStderr := "bonjour stderr"
@@ -345,6 +347,7 @@ func TestStdio(t *testing.T) {
}
func TestMain(m *testing.M) {
- testutil.EnsureSupportedDockerVersion()
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
os.Exit(m.Run())
}
diff --git a/runsc/test/image/latin10k.txt b/test/image/latin10k.txt
index 61341e00b..61341e00b 100644
--- a/runsc/test/image/latin10k.txt
+++ b/test/image/latin10k.txt
diff --git a/runsc/test/image/mysql.sql b/test/image/mysql.sql
index 51554b98d..51554b98d 100644
--- a/runsc/test/image/mysql.sql
+++ b/test/image/mysql.sql
diff --git a/runsc/test/image/ruby.rb b/test/image/ruby.rb
index aced49c6d..aced49c6d 100644
--- a/runsc/test/image/ruby.rb
+++ b/test/image/ruby.rb
diff --git a/runsc/test/image/ruby.sh b/test/image/ruby.sh
index ebe8d5b0e..ebe8d5b0e 100644
--- a/runsc/test/image/ruby.sh
+++ b/test/image/ruby.sh
diff --git a/runsc/test/root/BUILD b/test/root/BUILD
index 500ef7b8e..d5dd9bca2 100644
--- a/runsc/test/root/BUILD
+++ b/test/root/BUILD
@@ -5,7 +5,7 @@ package(licenses = ["notice"])
go_library(
name = "root",
srcs = ["root.go"],
- importpath = "gvisor.dev/gvisor/runsc/test/root",
+ importpath = "gvisor.dev/gvisor/test/root",
)
go_test(
@@ -15,6 +15,11 @@ go_test(
"cgroup_test.go",
"chroot_test.go",
"crictl_test.go",
+ "main_test.go",
+ "oom_score_adj_test.go",
+ ],
+ data = [
+ "//runsc",
],
embed = [":root"],
tags = [
@@ -23,11 +28,17 @@ go_test(
"manual",
"local",
],
+ visibility = ["//:sandbox"],
deps = [
+ "//runsc/boot",
"//runsc/cgroup",
+ "//runsc/container",
+ "//runsc/criutil",
+ "//runsc/dockerutil",
"//runsc/specutils",
- "//runsc/test/root/testdata",
- "//runsc/test/testutil",
+ "//runsc/testutil",
+ "//test/root/testdata",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
],
)
diff --git a/runsc/test/root/cgroup_test.go b/test/root/cgroup_test.go
index 5392dc6e0..76f1e4f2a 100644
--- a/runsc/test/root/cgroup_test.go
+++ b/test/root/cgroup_test.go
@@ -26,7 +26,8 @@ import (
"testing"
"gvisor.dev/gvisor/runsc/cgroup"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func verifyPid(pid int, path string) error {
@@ -56,11 +57,17 @@ func verifyPid(pid int, path string) error {
// TestCgroup sets cgroup options and checks that cgroup was properly configured.
func TestCgroup(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
+ // This is not a comprehensive list of attributes.
+ //
+ // Note that we are specifically missing cpusets, which fail if specified.
+ // In any case, it's unclear if cpusets can be reliably tested here: these
+ // are often run on a single core virtual machine, and there is only a single
+ // CPU available in our current set, and every container's set.
attrs := []struct {
arg string
ctrl string
@@ -87,18 +94,6 @@ func TestCgroup(t *testing.T) {
want: "3000",
},
{
- arg: "--cpuset-cpus=0",
- ctrl: "cpuset",
- file: "cpuset.cpus",
- want: "0",
- },
- {
- arg: "--cpuset-mems=0",
- ctrl: "cpuset",
- file: "cpuset.mems",
- want: "0",
- },
- {
arg: "--kernel-memory=100MB",
ctrl: "memory",
file: "memory.kmem.limit_in_bytes",
@@ -197,10 +192,10 @@ func TestCgroup(t *testing.T) {
}
func TestCgroupParent(t *testing.T) {
- if err := testutil.Pull("alpine"); err != nil {
+ if err := dockerutil.Pull("alpine"); err != nil {
t.Fatal("docker pull failed:", err)
}
- d := testutil.MakeDocker("cgroup-test")
+ d := dockerutil.MakeDocker("cgroup-test")
parent := testutil.RandomName("runsc")
if err := d.Run("--cgroup-parent", parent, "alpine", "sleep", "10000"); err != nil {
diff --git a/runsc/test/root/chroot_test.go b/test/root/chroot_test.go
index d0f236580..be0f63d18 100644
--- a/runsc/test/root/chroot_test.go
+++ b/test/root/chroot_test.go
@@ -12,33 +12,25 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package root is used for tests that requires sysadmin privileges run. First,
-// follow the setup instruction in runsc/test/README.md. To run these tests:
-//
-// bazel build //runsc/test/root:root_test
-// root_test=$(find -L ./bazel-bin/ -executable -type f -name root_test | grep __main__)
-// sudo RUNSC_RUNTIME=runsc-test ${root_test}
+// Package root is used for tests that requires sysadmin privileges run.
package root
import (
"fmt"
"io/ioutil"
- "os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"testing"
- "github.com/syndtr/gocapability/capability"
- "gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
)
// TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned
// up after the sandbox is destroyed.
func TestChroot(t *testing.T) {
- d := testutil.MakeDocker("chroot-test")
+ d := dockerutil.MakeDocker("chroot-test")
if err := d.Run("alpine", "sleep", "10000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -84,7 +76,7 @@ func TestChroot(t *testing.T) {
}
func TestChrootGofer(t *testing.T) {
- d := testutil.MakeDocker("chroot-test")
+ d := dockerutil.MakeDocker("chroot-test")
if err := d.Run("alpine", "sleep", "10000"); err != nil {
t.Fatalf("docker run failed: %v", err)
}
@@ -148,14 +140,3 @@ func TestChrootGofer(t *testing.T) {
}
}
}
-
-func TestMain(m *testing.M) {
- testutil.EnsureSupportedDockerVersion()
-
- if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) {
- fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.")
- os.Exit(1)
- }
-
- os.Exit(m.Run())
-}
diff --git a/runsc/test/root/crictl_test.go b/test/root/crictl_test.go
index 515ae2df1..d597664f5 100644
--- a/runsc/test/root/crictl_test.go
+++ b/test/root/crictl_test.go
@@ -29,14 +29,17 @@ import (
"testing"
"time"
+ "gvisor.dev/gvisor/runsc/criutil"
+ "gvisor.dev/gvisor/runsc/dockerutil"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/root/testdata"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
+ "gvisor.dev/gvisor/test/root/testdata"
)
// Tests for crictl have to be run as root (rather than in a user namespace)
// because crictl creates named network namespaces in /var/run/netns/.
+// TestCrictlSanity refers to b/112433158.
func TestCrictlSanity(t *testing.T) {
// Setup containerd and crictl.
crictl, cleanup, err := setup(t)
@@ -60,6 +63,7 @@ func TestCrictlSanity(t *testing.T) {
}
}
+// TestMountPaths refers to b/117635704.
func TestMountPaths(t *testing.T) {
// Setup containerd and crictl.
crictl, cleanup, err := setup(t)
@@ -83,6 +87,7 @@ func TestMountPaths(t *testing.T) {
}
}
+// TestMountPaths refers to b/118728671.
func TestMountOverSymlinks(t *testing.T) {
// Setup containerd and crictl.
crictl, cleanup, err := setup(t)
@@ -125,7 +130,7 @@ func TestMountOverSymlinks(t *testing.T) {
// * Creates directories and a socket for containerd to utilize.
// * Runs containerd and waits for it to reach a "ready" state for testing.
// * Returns a cleanup function that should be called at the end of the test.
-func setup(t *testing.T) (*testutil.Crictl, func(), error) {
+func setup(t *testing.T) (*criutil.Crictl, func(), error) {
var cleanups []func()
cleanupFunc := func() {
for i := len(cleanups) - 1; i >= 0; i-- {
@@ -149,12 +154,19 @@ func setup(t *testing.T) (*testutil.Crictl, func(), error) {
cleanups = append(cleanups, func() { os.RemoveAll(containerdState) })
sockAddr := filepath.Join(testutil.TmpDir(), "containerd-test.sock")
- // Start containerd.
- config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(getRunsc()))
+ // We rewrite a configuration. This is based on the current docker
+ // configuration for the runtime under test.
+ runtime, err := dockerutil.RuntimePath()
+ if err != nil {
+ t.Fatalf("error discovering runtime path: %v", err)
+ }
+ config, err := testutil.WriteTmpFile("containerd-config", testdata.ContainerdConfig(runtime))
if err != nil {
t.Fatalf("failed to write containerd config")
}
cleanups = append(cleanups, func() { os.RemoveAll(config) })
+
+ // Start containerd.
containerd := exec.Command(getContainerd(),
"--config", config,
"--log-level", "debug",
@@ -191,11 +203,11 @@ func setup(t *testing.T) (*testutil.Crictl, func(), error) {
})
cleanup.Release()
- return testutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil
+ return criutil.NewCrictl(20*time.Second, sockAddr), cleanupFunc, nil
}
// httpGet GETs the contents of a file served from a pod on port 80.
-func httpGet(crictl *testutil.Crictl, podID, filePath string) error {
+func httpGet(crictl *criutil.Crictl, podID, filePath string) error {
// Get the IP of the httpd server.
ip, err := crictl.PodIP(podID)
if err != nil {
@@ -222,21 +234,9 @@ func httpGet(crictl *testutil.Crictl, podID, filePath string) error {
}
func getContainerd() string {
- // Bazel doesn't pass PATH through, assume the location of containerd
- // unless specified by environment variable.
- c := os.Getenv("CONTAINERD_PATH")
- if c == "" {
+ // Use the local path if it exists, otherwise, use the system one.
+ if _, err := os.Stat("/usr/local/bin/containerd"); err == nil {
return "/usr/local/bin/containerd"
}
- return c
-}
-
-func getRunsc() string {
- // Bazel doesn't pass PATH through, assume the location of runsc unless
- // specified by environment variable.
- c := os.Getenv("RUNSC_EXEC")
- if c == "" {
- return "/tmp/runsc-test/runsc"
- }
- return c
+ return "/usr/bin/containerd"
}
diff --git a/test/root/main_test.go b/test/root/main_test.go
new file mode 100644
index 000000000..d74dec85f
--- /dev/null
+++ b/test/root/main_test.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "testing"
+
+ "github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// TestMain is the main function for root tests. This function checks the
+// supported docker version, required capabilities, and configures the executable
+// path for runsc.
+func TestMain(m *testing.M) {
+ flag.Parse()
+
+ if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) {
+ fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.")
+ os.Exit(1)
+ }
+
+ dockerutil.EnsureSupportedDockerVersion()
+
+ // Configure exe for tests.
+ path, err := dockerutil.RuntimePath()
+ if err != nil {
+ panic(err.Error())
+ }
+ specutils.ExePath = path
+
+ os.Exit(m.Run())
+}
diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go
new file mode 100644
index 000000000..6cd378a1b
--- /dev/null
+++ b/test/root/oom_score_adj_test.go
@@ -0,0 +1,376 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "fmt"
+ "os"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+var (
+ maxOOMScoreAdj = 1000
+ highOOMScoreAdj = 500
+ lowOOMScoreAdj = -500
+ minOOMScoreAdj = -1000
+)
+
+// Tests for oom_score_adj have to be run as root (rather than in a user
+// namespace) because we need to adjust oom_score_adj for PIDs other than our
+// own and test values below 0.
+
+// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a
+// single container sandbox.
+func TestOOMScoreAdjSingle(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set.
+ OOMScoreAdj *int
+ }{
+ {
+ Name: "max",
+ OOMScoreAdj: &maxOOMScoreAdj,
+ },
+ {
+ Name: "high",
+ OOMScoreAdj: &highOOMScoreAdj,
+ },
+ {
+ Name: "low",
+ OOMScoreAdj: &lowOOMScoreAdj,
+ },
+ {
+ Name: "min",
+ OOMScoreAdj: &minOOMScoreAdj,
+ },
+ {
+ Name: "nil",
+ OOMScoreAdj: &parentOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ id := testutil.UniqueContainerID()
+ s := testutil.NewSpecWithArgs("sleep", "1000")
+ s.Process.OOMScoreAdj = testCase.OOMScoreAdj
+
+ conf := testutil.TestConfig()
+ containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id})
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ c := containers[0]
+
+ // Verify the gofer's oom_score_adj
+ if testCase.OOMScoreAdj != nil {
+ goferScore, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if goferScore != *testCase.OOMScoreAdj {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", goferScore, *testCase.OOMScoreAdj)
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := c.Sandbox.Pid
+ sandboxScore, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if sandboxScore != *testCase.OOMScoreAdj {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", sandboxScore, *testCase.OOMScoreAdj)
+ }
+ }
+ })
+ }
+}
+
+// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a
+// multi-container sandbox.
+func TestOOMScoreAdjMulti(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set. One value for each container. The first value is the
+ // root container.
+ OOMScoreAdj []*int
+
+ // Expected is the expected oom_score_adj of the sandbox. If nil, then
+ // this value is ignored.
+ Expected *int
+
+ // Remove is a set of container indexes to remove from the sandbox.
+ Remove []int
+
+ // ExpectedAfterRemove is the expected oom_score_adj of the sandbox
+ // after containers are removed. Ignored if nil.
+ ExpectedAfterRemove *int
+ }{
+ // A single container CRI test case. This should not happen in
+ // practice as there should be at least one container besides the pause
+ // container. However, we include a test case to ensure sane behavior.
+ {
+ Name: "single",
+ OOMScoreAdj: []*int{&highOOMScoreAdj},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_no_value",
+ OOMScoreAdj: []*int{nil, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_non_nil_root",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &highOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_min_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_max_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ },
+ {
+ Name: "remove_adjusted",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove highOOMScoreAdj container.
+ Remove: []int{2},
+ ExpectedAfterRemove: &maxOOMScoreAdj,
+ },
+ {
+ // This test removes all non-root sandboxes with a specified oomScoreAdj.
+ Name: "remove_to_nil",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, &lowOOMScoreAdj},
+ Expected: &lowOOMScoreAdj,
+ // Remove lowOOMScoreAdj container.
+ Remove: []int{2},
+ // The oom_score_adj expected after remove is that of the parent process.
+ ExpectedAfterRemove: &parentOOMScoreAdj,
+ },
+ {
+ Name: "remove_no_effect",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove the maxOOMScoreAdj container.
+ Remove: []int{1},
+ ExpectedAfterRemove: &highOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ var cmds [][]string
+ var oomScoreAdj []*int
+ var toRemove []string
+
+ for _, oomScore := range testCase.OOMScoreAdj {
+ oomScoreAdj = append(oomScoreAdj, oomScore)
+ cmds = append(cmds, []string{"sleep", "100"})
+ }
+
+ specs, ids := createSpecs(cmds...)
+ for i, spec := range specs {
+ // Ensure the correct value is set, including no value.
+ spec.Process.OOMScoreAdj = oomScoreAdj[i]
+
+ for _, j := range testCase.Remove {
+ if i == j {
+ toRemove = append(toRemove, ids[i])
+ }
+ }
+ }
+
+ conf := testutil.TestConfig()
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ for i, c := range containers {
+ if oomScoreAdj[i] != nil {
+ // Verify the gofer's oom_score_adj
+ score, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if score != *oomScoreAdj[i] {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", score, *oomScoreAdj[i])
+ }
+ }
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := containers[0].Sandbox.Pid
+ if testCase.Expected != nil {
+ score, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if score != *testCase.Expected {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", score, *testCase.Expected)
+ }
+ }
+
+ if len(toRemove) == 0 {
+ return
+ }
+
+ // Remove containers.
+ for _, removeID := range toRemove {
+ for _, c := range containers {
+ if c.ID == removeID {
+ c.Destroy()
+ }
+ }
+ }
+
+ // Check the new adjusted oom_score_adj.
+ if testCase.ExpectedAfterRemove != nil {
+ scoreAfterRemove, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if scoreAfterRemove != *testCase.ExpectedAfterRemove {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", scoreAfterRemove, *testCase.ExpectedAfterRemove)
+ }
+ }
+ })
+ }
+}
+
+func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
+ var specs []*specs.Spec
+ var ids []string
+ rootID := testutil.UniqueContainerID()
+
+ for i, cmd := range cmds {
+ spec := testutil.NewSpecWithArgs(cmd...)
+ if i == 0 {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox,
+ }
+ ids = append(ids, rootID)
+ } else {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer,
+ specutils.ContainerdSandboxIDAnnotation: rootID,
+ }
+ ids = append(ids, testutil.UniqueContainerID())
+ }
+ specs = append(specs, spec)
+ }
+ return specs, ids
+}
+
+func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) {
+ // Setup root dir if one hasn't been provided.
+ if len(conf.RootDir) == 0 {
+ rootDir, err := testutil.SetupRootDir()
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating root dir: %v", err)
+ }
+ conf.RootDir = rootDir
+ }
+
+ var containers []*container.Container
+ var bundles []string
+ cleanup := func() {
+ for _, c := range containers {
+ c.Destroy()
+ }
+ for _, b := range bundles {
+ os.RemoveAll(b)
+ }
+ os.RemoveAll(conf.RootDir)
+ }
+ for i, spec := range specs {
+ bundleDir, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error setting up container: %v", err)
+ }
+ bundles = append(bundles, bundleDir)
+
+ args := container.Args{
+ ID: ids[i],
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := container.New(conf, args)
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error creating container: %v", err)
+ }
+ containers = append(containers, cont)
+
+ if err := cont.Start(conf); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error starting container: %v", err)
+ }
+ }
+ return containers, cleanup, nil
+}
diff --git a/test/root/root.go b/test/root/root.go
new file mode 100644
index 000000000..0f1d29faf
--- /dev/null
+++ b/test/root/root.go
@@ -0,0 +1,21 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package root is used for tests that requires sysadmin privileges run. First,
+// follow the setup instruction in runsc/test/README.md. You should also have
+// docker, containerd, and crictl installed. To run these tests from the
+// project root directory:
+//
+// ./scripts/root_tests.sh
+package root
diff --git a/runsc/test/root/testdata/BUILD b/test/root/testdata/BUILD
index 80dc5f214..14c19ef1e 100644
--- a/runsc/test/root/testdata/BUILD
+++ b/test/root/testdata/BUILD
@@ -11,7 +11,7 @@ go_library(
"httpd_mount_paths.go",
"sandbox.go",
],
- importpath = "gvisor.dev/gvisor/runsc/test/root/testdata",
+ importpath = "gvisor.dev/gvisor/test/root/testdata",
visibility = [
"//visibility:public",
],
diff --git a/runsc/test/root/testdata/busybox.go b/test/root/testdata/busybox.go
index e4dbd2843..e4dbd2843 100644
--- a/runsc/test/root/testdata/busybox.go
+++ b/test/root/testdata/busybox.go
diff --git a/runsc/test/root/testdata/containerd_config.go b/test/root/testdata/containerd_config.go
index e12f1ec88..e12f1ec88 100644
--- a/runsc/test/root/testdata/containerd_config.go
+++ b/test/root/testdata/containerd_config.go
diff --git a/runsc/test/root/testdata/httpd.go b/test/root/testdata/httpd.go
index 45d5e33d4..45d5e33d4 100644
--- a/runsc/test/root/testdata/httpd.go
+++ b/test/root/testdata/httpd.go
diff --git a/runsc/test/root/testdata/httpd_mount_paths.go b/test/root/testdata/httpd_mount_paths.go
index ac3f4446a..ac3f4446a 100644
--- a/runsc/test/root/testdata/httpd_mount_paths.go
+++ b/test/root/testdata/httpd_mount_paths.go
diff --git a/runsc/test/root/testdata/sandbox.go b/test/root/testdata/sandbox.go
index 0db210370..0db210370 100644
--- a/runsc/test/root/testdata/sandbox.go
+++ b/test/root/testdata/sandbox.go
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index e85804a83..1cde74cfc 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -1,25 +1,42 @@
# These packages are used to run language runtime tests inside gVisor sandboxes.
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
-load("//runsc/test:build_defs.bzl", "runtime_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//test/runtimes:build_defs.bzl", "runtime_test")
package(licenses = ["notice"])
-go_library(
- name = "runtimes",
- srcs = ["runtimes.go"],
- importpath = "gvisor.dev/gvisor/test/runtimes",
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["runner.go"],
+ deps = [
+ "//runsc/dockerutil",
+ "//runsc/testutil",
+ ],
)
runtime_test(
- name = "runtimes_test",
- size = "small",
- srcs = ["runtimes_test.go"],
- embed = [":runtimes"],
- tags = [
- # Requires docker and runsc to be configured before the test runs.
- "manual",
- "local",
- ],
- deps = ["//runsc/test/testutil"],
+ image = "gcr.io/gvisor-presubmit/go1.12",
+ lang = "go",
+)
+
+runtime_test(
+ image = "gcr.io/gvisor-presubmit/java11",
+ lang = "java",
+)
+
+runtime_test(
+ blacklist_file = "blacklist_nodejs12.4.0.csv",
+ image = "gcr.io/gvisor-presubmit/nodejs12.4.0",
+ lang = "nodejs",
+)
+
+runtime_test(
+ image = "gcr.io/gvisor-presubmit/php7.3.6",
+ lang = "php",
+)
+
+runtime_test(
+ image = "gcr.io/gvisor-presubmit/python3.7.3",
+ lang = "python",
)
diff --git a/test/runtimes/README.md b/test/runtimes/README.md
index 34d3507be..e41e78f77 100644
--- a/test/runtimes/README.md
+++ b/test/runtimes/README.md
@@ -16,10 +16,11 @@ The following runtimes are currently supported:
1) [Install and configure Docker](https://docs.docker.com/install/)
-2) Build each Docker container from the runtimes directory:
+2) Build each Docker container from the runtimes/images directory:
```bash
-$ docker build -f $LANG/Dockerfile [-t $NAME] .
+$ cd images
+$ docker build -f Dockerfile_$LANG [-t $NAME] .
```
### Testing:
diff --git a/test/runtimes/blacklist_nodejs12.4.0.csv b/test/runtimes/blacklist_nodejs12.4.0.csv
new file mode 100644
index 000000000..9135d763c
--- /dev/null
+++ b/test/runtimes/blacklist_nodejs12.4.0.csv
@@ -0,0 +1,47 @@
+test name,bug id,comment
+benchmark/test-benchmark-fs.js,,
+benchmark/test-benchmark-module.js,,
+benchmark/test-benchmark-napi.js,,
+doctool/test-make-doc.js,b/68848110,Expected to fail.
+fixtures/test-error-first-line-offset.js,,
+fixtures/test-fs-readfile-error.js,,
+fixtures/test-fs-stat-sync-overflow.js,,
+internet/test-dgram-broadcast-multi-process.js,,
+internet/test-dgram-multicast-multi-process.js,,
+internet/test-dgram-multicast-set-interface-lo.js,,
+parallel/test-cluster-dgram-reuse.js,b/64024294,
+parallel/test-dgram-bind-fd.js,b/132447356,
+parallel/test-dgram-create-socket-handle-fd.js,b/132447238,
+parallel/test-dgram-createSocket-type.js,b/68847739,
+parallel/test-dgram-socket-buffer-size.js,b/68847921,
+parallel/test-fs-access.js,,
+parallel/test-fs-write-stream-double-close.js
+parallel/test-fs-write-stream-throw-type-error.js,b/110226209,
+parallel/test-fs-write-stream.js,,
+parallel/test-http2-respond-file-error-pipe-offset.js,,
+parallel/test-os.js,,
+parallel/test-process-uid-gid.js,,
+pseudo-tty/test-assert-colors.js,,
+pseudo-tty/test-assert-no-color.js,,
+pseudo-tty/test-assert-position-indicator.js,,
+pseudo-tty/test-async-wrap-getasyncid-tty.js,,
+pseudo-tty/test-fatal-error.js,,
+pseudo-tty/test-handle-wrap-isrefed-tty.js,,
+pseudo-tty/test-readable-tty-keepalive.js,,
+pseudo-tty/test-set-raw-mode-reset-process-exit.js,,
+pseudo-tty/test-set-raw-mode-reset-signal.js,,
+pseudo-tty/test-set-raw-mode-reset.js,,
+pseudo-tty/test-stderr-stdout-handle-sigwinch.js,,
+pseudo-tty/test-stdout-read.js,,
+pseudo-tty/test-tty-color-support.js,,
+pseudo-tty/test-tty-isatty.js,,
+pseudo-tty/test-tty-stdin-call-end.js,,
+pseudo-tty/test-tty-stdin-end.js,,
+pseudo-tty/test-stdin-write.js,,
+pseudo-tty/test-tty-stdout-end.js,,
+pseudo-tty/test-tty-stdout-resize.js,,
+pseudo-tty/test-tty-stream-constructors.js,,
+pseudo-tty/test-tty-window-size.js,,
+pseudo-tty/test-tty-wrap.js,,
+pummel/test-net-pingpong.js,,
+pummel/test-vm-memleak.js,,
diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl
new file mode 100644
index 000000000..7edd12c17
--- /dev/null
+++ b/test/runtimes/build_defs.bzl
@@ -0,0 +1,42 @@
+"""Defines a rule for runsc test targets."""
+
+# runtime_test is a macro that will create targets to run the given test target
+# with different runtime options.
+def runtime_test(
+ lang,
+ image,
+ shard_count = 50,
+ size = "enormous",
+ blacklist_file = ""):
+ args = [
+ "--lang",
+ lang,
+ "--image",
+ image,
+ ]
+ data = [
+ ":runner",
+ ]
+ if blacklist_file != "":
+ args += ["--blacklist_file", "test/runtimes/" + blacklist_file]
+ data += [blacklist_file]
+
+ sh_test(
+ name = lang + "_test",
+ srcs = ["runner.sh"],
+ args = args,
+ data = data,
+ size = size,
+ shard_count = shard_count,
+ tags = [
+ # Requires docker and runsc to be configured before the test runs.
+ "manual",
+ "local",
+ ],
+ )
+
+def sh_test(**kwargs):
+ """Wraps the standard sh_test."""
+ native.sh_test(
+ **kwargs
+ )
diff --git a/test/runtimes/common/BUILD b/test/runtimes/common/BUILD
deleted file mode 100644
index 1b39606b8..000000000
--- a/test/runtimes/common/BUILD
+++ /dev/null
@@ -1,20 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "common",
- srcs = ["common.go"],
- importpath = "gvisor.dev/gvisor/test/runtimes/common",
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "common_test",
- size = "small",
- srcs = ["common_test.go"],
- deps = [
- ":common",
- "//runsc/test/testutil",
- ],
-)
diff --git a/test/runtimes/common/common.go b/test/runtimes/common/common.go
deleted file mode 100644
index 0ff87fa8b..000000000
--- a/test/runtimes/common/common.go
+++ /dev/null
@@ -1,114 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package common executes functions for proctor binaries.
-package common
-
-import (
- "flag"
- "fmt"
- "os"
- "path/filepath"
- "regexp"
-)
-
-var (
- list = flag.Bool("list", false, "list all available tests")
- test = flag.String("test", "", "run a single test from the list of available tests")
- version = flag.Bool("v", false, "print out the version of node that is installed")
-)
-
-// TestRunner is an interface to be implemented in each proctor binary.
-type TestRunner interface {
- // ListTests returns a string slice of tests available to run.
- ListTests() ([]string, error)
-
- // RunTest runs a single test.
- RunTest(test string) error
-}
-
-// LaunchFunc parses flags passed by a proctor binary and calls the requested behavior.
-func LaunchFunc(tr TestRunner) error {
- flag.Parse()
-
- if *list && *test != "" {
- flag.PrintDefaults()
- return fmt.Errorf("cannot specify 'list' and 'test' flags simultaneously")
- }
- if *list {
- tests, err := tr.ListTests()
- if err != nil {
- return fmt.Errorf("failed to list tests: %v", err)
- }
- for _, test := range tests {
- fmt.Println(test)
- }
- return nil
- }
- if *version {
- fmt.Println(os.Getenv("LANG_NAME"), "version:", os.Getenv("LANG_VER"), "is installed.")
- return nil
- }
- if *test != "" {
- if err := tr.RunTest(*test); err != nil {
- return fmt.Errorf("test %q failed to run: %v", *test, err)
- }
- return nil
- }
-
- if err := runAllTests(tr); err != nil {
- return fmt.Errorf("error running all tests: %v", err)
- }
- return nil
-}
-
-// Search uses filepath.Walk to perform a search of the disk for test files
-// and returns a string slice of tests.
-func Search(root string, testFilter *regexp.Regexp) ([]string, error) {
- var testSlice []string
-
- err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
- name := filepath.Base(path)
-
- if info.IsDir() || !testFilter.MatchString(name) {
- return nil
- }
-
- relPath, err := filepath.Rel(root, path)
- if err != nil {
- return err
- }
- testSlice = append(testSlice, relPath)
- return nil
- })
-
- if err != nil {
- return nil, fmt.Errorf("walking %q: %v", root, err)
- }
-
- return testSlice, nil
-}
-
-func runAllTests(tr TestRunner) error {
- tests, err := tr.ListTests()
- if err != nil {
- return fmt.Errorf("failed to list tests: %v", err)
- }
- for _, test := range tests {
- if err := tr.RunTest(test); err != nil {
- return fmt.Errorf("test %q failed to run: %v", test, err)
- }
- }
- return nil
-}
diff --git a/test/runtimes/go/BUILD b/test/runtimes/go/BUILD
deleted file mode 100644
index ce971ee9d..000000000
--- a/test/runtimes/go/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-go",
- srcs = ["proctor-go.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/go/Dockerfile b/test/runtimes/go/Dockerfile
deleted file mode 100644
index 1d5202b70..000000000
--- a/test/runtimes/go/Dockerfile
+++ /dev/null
@@ -1,34 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=1.12.5
-ENV LANG_NAME=Go
-
-RUN apt-get update && apt-get install -y \
- curl \
- gcc \
- git
-
-WORKDIR /root
-
-# Download Go 1.4 to use as a bootstrap for building Go from the source.
-RUN curl -o go1.4.linux-amd64.tar.gz https://dl.google.com/go/go1.4.linux-amd64.tar.gz
-RUN curl -LJO https://github.com/golang/go/archive/go${LANG_VER}.tar.gz
-RUN mkdir bootstr
-RUN tar -C bootstr -xzf go1.4.linux-amd64.tar.gz
-RUN tar -xzf go-go${LANG_VER}.tar.gz
-RUN mv go-go${LANG_VER} go
-
-ENV GOROOT=/root/go
-ENV GOROOT_BOOTSTRAP=/root/bootstr/go
-ENV LANG_DIR=${GOROOT}
-
-WORKDIR ${LANG_DIR}/src
-RUN ./make.bash
-# Pre-compile the tests for faster execution
-RUN ["/root/go/bin/go", "tool", "dist", "test", "-compile-only"]
-
-WORKDIR ${LANG_DIR}
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY go/proctor-go.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-go.go"]
diff --git a/test/runtimes/images/Dockerfile_go1.12 b/test/runtimes/images/Dockerfile_go1.12
new file mode 100644
index 000000000..ab9d6abf3
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_go1.12
@@ -0,0 +1,10 @@
+# Go is easy, since we already have everything we need to compile the proctor
+# binary and run the tests in the golang Docker image.
+FROM golang:1.12
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+# Pre-compile the tests so we don't need to do so in each test run.
+RUN ["go", "tool", "dist", "test", "-compile-only"]
+
+ENTRYPOINT ["/proctor", "--runtime=go"]
diff --git a/test/runtimes/images/Dockerfile_java11 b/test/runtimes/images/Dockerfile_java11
new file mode 100644
index 000000000..9b7c3d5a3
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_java11
@@ -0,0 +1,30 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ autoconf \
+ build-essential \
+ curl \
+ make \
+ openjdk-11-jdk \
+ unzip \
+ zip
+
+# Download the JDK test library.
+WORKDIR /root
+RUN set -ex \
+ && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk11/archive/76072a077ee1.tar.gz/test \
+ && tar -xzf /tmp/jdktests.tar.gz \
+ && mv jdk11-76072a077ee1/test test \
+ && rm -f /tmp/jdktests.tar.gz
+
+# Install jtreg and add to PATH.
+RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
+RUN tar -xzf jtreg.tar.gz
+ENV PATH="/root/jtreg/bin:$PATH"
+
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=java"]
diff --git a/test/runtimes/images/Dockerfile_nodejs12.4.0 b/test/runtimes/images/Dockerfile_nodejs12.4.0
new file mode 100644
index 000000000..26f68b487
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_nodejs12.4.0
@@ -0,0 +1,28 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ curl \
+ dumb-init \
+ g++ \
+ make \
+ python
+
+WORKDIR /root
+ARG VERSION=v12.4.0
+RUN curl -o node-${VERSION}.tar.gz https://nodejs.org/dist/${VERSION}/node-${VERSION}.tar.gz
+RUN tar -zxf node-${VERSION}.tar.gz
+
+WORKDIR /root/node-${VERSION}
+RUN ./configure
+RUN make
+RUN make test-build
+
+COPY --from=golang /proctor /proctor
+
+# Including dumb-init emulates the Linux "init" process, preventing the failure
+# of tests involving worker processes.
+ENTRYPOINT ["/usr/bin/dumb-init", "/proctor", "--runtime=nodejs"]
diff --git a/test/runtimes/images/Dockerfile_php7.3.6 b/test/runtimes/images/Dockerfile_php7.3.6
new file mode 100644
index 000000000..e6b4c6329
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_php7.3.6
@@ -0,0 +1,27 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ autoconf \
+ automake \
+ bison \
+ build-essential \
+ curl \
+ libtool \
+ libxml2-dev \
+ re2c
+
+WORKDIR /root
+ARG VERSION=7.3.6
+RUN curl -o php-${VERSION}.tar.gz https://www.php.net/distributions/php-${VERSION}.tar.gz
+RUN tar -zxf php-${VERSION}.tar.gz
+
+WORKDIR /root/php-${VERSION}
+RUN ./configure
+RUN make
+
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=php"]
diff --git a/test/runtimes/images/Dockerfile_python3.7.3 b/test/runtimes/images/Dockerfile_python3.7.3
new file mode 100644
index 000000000..905cd22d7
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_python3.7.3
@@ -0,0 +1,30 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+
+RUN apt-get update && apt-get install -y \
+ curl \
+ gcc \
+ libbz2-dev \
+ libffi-dev \
+ liblzma-dev \
+ libreadline-dev \
+ libssl-dev \
+ make \
+ zlib1g-dev
+
+# Use flags -LJO to follow the html redirect and download .tar.gz.
+WORKDIR /root
+ARG VERSION=3.7.3
+RUN curl -LJO https://github.com/python/cpython/archive/v${VERSION}.tar.gz
+RUN tar -zxf cpython-${VERSION}.tar.gz
+
+WORKDIR /root/cpython-${VERSION}
+RUN ./configure --with-pydebug
+RUN make -s -j2
+
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=python"]
diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD
new file mode 100644
index 000000000..09dc6c42f
--- /dev/null
+++ b/test/runtimes/images/proctor/BUILD
@@ -0,0 +1,26 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "proctor",
+ srcs = [
+ "go.go",
+ "java.go",
+ "nodejs.go",
+ "php.go",
+ "proctor.go",
+ "python.go",
+ ],
+ visibility = ["//test/runtimes/images:__subpackages__"],
+)
+
+go_test(
+ name = "proctor_test",
+ size = "small",
+ srcs = ["proctor_test.go"],
+ embed = [":proctor"],
+ deps = [
+ "//runsc/testutil",
+ ],
+)
diff --git a/test/runtimes/go/proctor-go.go b/test/runtimes/images/proctor/go.go
index 3eb24576e..3e2d5d8db 100644
--- a/test/runtimes/go/proctor-go.go
+++ b/test/runtimes/images/proctor/go.go
@@ -12,50 +12,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-go is a utility that facilitates language testing for Go.
-
-// There are two types of Go tests: "Go tool tests" and "Go tests on disk".
-// "Go tool tests" are found and executed using `go tool dist test`.
-// "Go tests on disk" are found in the /test directory and are
-// executed using `go run run.go`.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"regexp"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
var (
- dir = os.Getenv("LANG_DIR")
- goBin = filepath.Join(dir, "bin/go")
- testDir = filepath.Join(dir, "test")
- testRegEx = regexp.MustCompile(`^.+\.go$`)
+ goTestRegEx = regexp.MustCompile(`^.+\.go$`)
// Directories with .dir contain helper files for tests.
// Exclude benchmarks and stress tests.
- dirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`)
+ goDirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`)
)
-type goRunner struct {
-}
+// Location of Go tests on disk.
+const goTestDir = "/usr/local/go/test"
-func main() {
- if err := common.LaunchFunc(goRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+// goRunner implements TestRunner for Go.
+//
+// There are two types of Go tests: "Go tool tests" and "Go tests on disk".
+// "Go tool tests" are found and executed using `go tool dist test`. "Go tests
+// on disk" are found in the /usr/local/go/test directory and are executed
+// using `go run run.go`.
+type goRunner struct{}
+
+var _ TestRunner = goRunner{}
-func (g goRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (goRunner) ListTests() ([]string, error) {
// Go tool dist test tests.
args := []string{"tool", "dist", "test", "-list"}
- cmd := exec.Command(filepath.Join(dir, "bin/go"), args...)
+ cmd := exec.Command("go", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -67,14 +59,14 @@ func (g goRunner) ListTests() ([]string, error) {
}
// Go tests on disk.
- diskSlice, err := common.Search(testDir, testRegEx)
+ diskSlice, err := search(goTestDir, goTestRegEx)
if err != nil {
return nil, err
}
// Remove items from /bench/, /stress/ and .dir files
diskFiltered := diskSlice[:0]
for _, file := range diskSlice {
- if !dirFilter.MatchString(file) {
+ if !goDirFilter.MatchString(file) {
diskFiltered = append(diskFiltered, file)
}
}
@@ -82,24 +74,17 @@ func (g goRunner) ListTests() ([]string, error) {
return append(toolSlice, diskFiltered...), nil
}
-func (g goRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (goRunner) TestCmd(test string) *exec.Cmd {
// Check if test exists on disk by searching for file of the same name.
// This will determine whether or not it is a Go test on disk.
if strings.HasSuffix(test, ".go") {
// Test has suffix ".go" which indicates a disk test, run it as such.
- cmd := exec.Command(goBin, "run", "run.go", "-v", "--", test)
- cmd.Dir = testDir
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run test: %v", err)
- }
- } else {
- // No ".go" suffix, run as a tool test.
- cmd := exec.Command(goBin, "tool", "dist", "test", "-run", test)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run test: %v", err)
- }
+ cmd := exec.Command("go", "run", "run.go", "-v", "--", test)
+ cmd.Dir = goTestDir
+ return cmd
}
- return nil
+
+ // No ".go" suffix, run as a tool test.
+ return exec.Command("go", "tool", "dist", "test", "-run", test)
}
diff --git a/test/runtimes/java/proctor-java.go b/test/runtimes/images/proctor/java.go
index 7f6a66f4f..8b362029d 100644
--- a/test/runtimes/java/proctor-java.go
+++ b/test/runtimes/images/proctor/java.go
@@ -12,40 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-java is a utility that facilitates language testing for Java.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"regexp"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
- hash = os.Getenv("LANG_HASH")
- jtreg = filepath.Join(dir, "jtreg/bin/jtreg")
- exclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`)
-)
+// Directories to exclude from tests.
+var javaExclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`)
-type javaRunner struct {
-}
+// Location of java tests.
+const javaTestDir = "/root/test/jdk"
-func main() {
- if err := common.LaunchFunc(javaRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+// javaRunner implements TestRunner for Java.
+type javaRunner struct{}
+
+var _ TestRunner = javaRunner{}
-func (j javaRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (javaRunner) ListTests() ([]string, error) {
args := []string{
- "-dir:/root/jdk11-" + hash + "/test/jdk",
+ "-dir:" + javaTestDir,
"-ignore:quiet",
"-a",
"-listtests",
@@ -54,7 +45,7 @@ func (j javaRunner) ListTests() ([]string, error) {
":jdk_sound",
":jdk_imageio",
}
- cmd := exec.Command(jtreg, args...)
+ cmd := exec.Command("jtreg", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -62,19 +53,19 @@ func (j javaRunner) ListTests() ([]string, error) {
}
var testSlice []string
for _, test := range strings.Split(string(out), "\n") {
- if !exclDirs.MatchString(test) {
+ if !javaExclDirs.MatchString(test) {
testSlice = append(testSlice, test)
}
}
return testSlice, nil
}
-func (j javaRunner) RunTest(test string) error {
- args := []string{"-noreport", "-dir:/root/jdk11-" + hash + "/test/jdk", test}
- cmd := exec.Command(jtreg, args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
+// TestCmd implements TestRunner.TestCmd.
+func (javaRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{
+ "-noreport",
+ "-dir:" + javaTestDir,
+ test,
}
- return nil
+ return exec.Command("jtreg", args...)
}
diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/images/proctor/nodejs.go
new file mode 100644
index 000000000..bd57db444
--- /dev/null
+++ b/test/runtimes/images/proctor/nodejs.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "os/exec"
+ "path/filepath"
+ "regexp"
+)
+
+var nodejsTestRegEx = regexp.MustCompile(`^test-[^-].+\.js$`)
+
+// Location of nodejs tests relative to working dir.
+const nodejsTestDir = "test"
+
+// nodejsRunner implements TestRunner for NodeJS.
+type nodejsRunner struct{}
+
+var _ TestRunner = nodejsRunner{}
+
+// ListTests implements TestRunner.ListTests.
+func (nodejsRunner) ListTests() ([]string, error) {
+ testSlice, err := search(nodejsTestDir, nodejsTestRegEx)
+ if err != nil {
+ return nil, err
+ }
+ return testSlice, nil
+}
+
+// TestCmd implements TestRunner.TestCmd.
+func (nodejsRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{filepath.Join("tools", "test.py"), test}
+ return exec.Command("/usr/bin/python", args...)
+}
diff --git a/test/runtimes/php/proctor-php.go b/test/runtimes/images/proctor/php.go
index e6c5fabdf..9115040e1 100644
--- a/test/runtimes/php/proctor-php.go
+++ b/test/runtimes/images/proctor/php.go
@@ -12,47 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-php is a utility that facilitates language testing for PHP.
package main
import (
- "fmt"
- "log"
- "os"
"os/exec"
"regexp"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
- testRegEx = regexp.MustCompile(`^.+\.phpt$`)
-)
+var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`)
-type phpRunner struct {
-}
+// phpRunner implements TestRunner for PHP.
+type phpRunner struct{}
-func main() {
- if err := common.LaunchFunc(phpRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+var _ TestRunner = phpRunner{}
-func (p phpRunner) ListTests() ([]string, error) {
- testSlice, err := common.Search(dir, testRegEx)
+// ListTests implements TestRunner.ListTests.
+func (phpRunner) ListTests() ([]string, error) {
+ testSlice, err := search(".", phpTestRegEx)
if err != nil {
return nil, err
}
return testSlice, nil
}
-func (p phpRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (phpRunner) TestCmd(test string) *exec.Cmd {
args := []string{"test", "TESTS=" + test}
- cmd := exec.Command("make", args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
+ return exec.Command("make", args...)
}
diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go
new file mode 100644
index 000000000..e6178e82b
--- /dev/null
+++ b/test/runtimes/images/proctor/proctor.go
@@ -0,0 +1,154 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary proctor runs the test for a particular runtime. It is meant to be
+// included in Docker images for all runtime tests.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "regexp"
+ "syscall"
+)
+
+// TestRunner is an interface that must be implemented for each runtime
+// integrated with proctor.
+type TestRunner interface {
+ // ListTests returns a string slice of tests available to run.
+ ListTests() ([]string, error)
+
+ // TestCmd returns an *exec.Cmd that will run the given test.
+ TestCmd(test string) *exec.Cmd
+}
+
+var (
+ runtime = flag.String("runtime", "", "name of runtime")
+ list = flag.Bool("list", false, "list all available tests")
+ test = flag.String("test", "", "run a single test from the list of available tests")
+ pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
+)
+
+func main() {
+ flag.Parse()
+
+ if *pause {
+ pauseAndReap()
+ panic("pauseAndReap should never return")
+ }
+
+ if *runtime == "" {
+ log.Fatalf("runtime flag must be provided")
+ }
+
+ tr, err := testRunnerForRuntime(*runtime)
+ if err != nil {
+ log.Fatalf("%v", err)
+ }
+
+ // List tests.
+ if *list {
+ tests, err := tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to list tests: %v", err)
+ }
+ for _, test := range tests {
+ fmt.Println(test)
+ }
+ return
+ }
+
+ // Run a single test.
+ if *test == "" {
+ log.Fatalf("test flag must be provided")
+ }
+ cmd := tr.TestCmd(*test)
+ cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
+ if err := cmd.Run(); err != nil {
+ log.Fatalf("FAIL: %v", err)
+ }
+}
+
+// testRunnerForRuntime returns a new TestRunner for the given runtime.
+func testRunnerForRuntime(runtime string) (TestRunner, error) {
+ switch runtime {
+ case "go":
+ return goRunner{}, nil
+ case "java":
+ return javaRunner{}, nil
+ case "nodejs":
+ return nodejsRunner{}, nil
+ case "php":
+ return phpRunner{}, nil
+ case "python":
+ return pythonRunner{}, nil
+ }
+ return nil, fmt.Errorf("invalid runtime %q", runtime)
+}
+
+// pauseAndReap is like init. It runs forever and reaps any children.
+func pauseAndReap() {
+ // Get notified of any new children.
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch, syscall.SIGCHLD)
+
+ for {
+ if _, ok := <-ch; !ok {
+ // Channel closed. This should not happen.
+ panic("signal channel closed")
+ }
+
+ // Reap the child.
+ for {
+ if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 {
+ break
+ }
+ }
+ }
+}
+
+// search is a helper function to find tests in the given directory that match
+// the regex.
+func search(root string, testFilter *regexp.Regexp) ([]string, error) {
+ var testSlice []string
+
+ err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+
+ name := filepath.Base(path)
+
+ if info.IsDir() || !testFilter.MatchString(name) {
+ return nil
+ }
+
+ relPath, err := filepath.Rel(root, path)
+ if err != nil {
+ return err
+ }
+ testSlice = append(testSlice, relPath)
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("walking %q: %v", root, err)
+ }
+
+ return testSlice, nil
+}
diff --git a/test/runtimes/common/common_test.go b/test/runtimes/images/proctor/proctor_test.go
index 4fb1e482a..6bb61d142 100644
--- a/test/runtimes/common/common_test.go
+++ b/test/runtimes/images/proctor/proctor_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package common_test
+package main
import (
"io/ioutil"
@@ -23,8 +23,7 @@ import (
"strings"
"testing"
- "gvisor.dev/gvisor/runsc/test/testutil"
- "gvisor.dev/gvisor/test/runtimes/common"
+ "gvisor.dev/gvisor/runsc/testutil"
)
func touch(t *testing.T, name string) {
@@ -48,9 +47,9 @@ func TestSearchEmptyDir(t *testing.T) {
var want []string
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := common.Search(td, testFilter)
+ got, err := search(td, testFilter)
if err != nil {
- t.Errorf("Search error: %v", err)
+ t.Errorf("search error: %v", err)
}
if !reflect.DeepEqual(got, want) {
@@ -117,9 +116,9 @@ func TestSearch(t *testing.T) {
}
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := common.Search(td, testFilter)
+ got, err := search(td, testFilter)
if err != nil {
- t.Errorf("Search error: %v", err)
+ t.Errorf("search error: %v", err)
}
if !reflect.DeepEqual(got, want) {
diff --git a/test/runtimes/python/proctor-python.go b/test/runtimes/images/proctor/python.go
index 35e28a7df..b9e0fbe6f 100644
--- a/test/runtimes/python/proctor-python.go
+++ b/test/runtimes/images/proctor/python.go
@@ -12,36 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-python is a utility that facilitates language testing for Pyhton.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
-)
+// pythonRunner implements TestRunner for Python.
+type pythonRunner struct{}
-type pythonRunner struct {
-}
+var _ TestRunner = pythonRunner{}
-func main() {
- if err := common.LaunchFunc(pythonRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
-
-func (p pythonRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (pythonRunner) ListTests() ([]string, error) {
args := []string{"-m", "test", "--list-tests"}
- cmd := exec.Command(filepath.Join(dir, "python"), args...)
+ cmd := exec.Command("./python", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -54,12 +42,8 @@ func (p pythonRunner) ListTests() ([]string, error) {
return toolSlice, nil
}
-func (p pythonRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (pythonRunner) TestCmd(test string) *exec.Cmd {
args := []string{"-m", "test", test}
- cmd := exec.Command(filepath.Join(dir, "python"), args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
+ return exec.Command("./python", args...)
}
diff --git a/test/runtimes/java/BUILD b/test/runtimes/java/BUILD
deleted file mode 100644
index 8c39d39ec..000000000
--- a/test/runtimes/java/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-java",
- srcs = ["proctor-java.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/java/Dockerfile b/test/runtimes/java/Dockerfile
deleted file mode 100644
index b9132b575..000000000
--- a/test/runtimes/java/Dockerfile
+++ /dev/null
@@ -1,35 +0,0 @@
-FROM ubuntu:bionic
-# This hash is associated with a specific JDK release and needed for ensuring
-# the same version is downloaded every time.
-ENV LANG_HASH=76072a077ee1
-ENV LANG_VER=11
-ENV LANG_NAME=Java
-
-RUN apt-get update && apt-get install -y \
- autoconf \
- build-essential \
- curl \
- make \
- openjdk-${LANG_VER}-jdk \
- unzip \
- zip
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-# Download the JDK test library.
-RUN set -ex \
- && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk${LANG_VER}/archive/${LANG_HASH}.tar.gz/test \
- && tar -xzf /tmp/jdktests.tar.gz -C /root \
- && rm -f /tmp/jdktests.tar.gz
-
-RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
-RUN tar -xzf jtreg.tar.gz
-
-ENV LANG_DIR=/root
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY java/proctor-java.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-java.go"]
diff --git a/test/runtimes/nodejs/BUILD b/test/runtimes/nodejs/BUILD
deleted file mode 100644
index 0594c250b..000000000
--- a/test/runtimes/nodejs/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-nodejs",
- srcs = ["proctor-nodejs.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/nodejs/Dockerfile b/test/runtimes/nodejs/Dockerfile
deleted file mode 100644
index aba30d2ee..000000000
--- a/test/runtimes/nodejs/Dockerfile
+++ /dev/null
@@ -1,30 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=12.4.0
-ENV LANG_NAME=Node
-
-RUN apt-get update && apt-get install -y \
- curl \
- dumb-init \
- g++ \
- make \
- python
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-RUN curl -o node-v${LANG_VER}.tar.gz https://nodejs.org/dist/v${LANG_VER}/node-v${LANG_VER}.tar.gz
-RUN tar -zxf node-v${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/node-v${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure
-RUN make
-RUN make test-build
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY nodejs/proctor-nodejs.go ${LANG_DIR}
-
-# Including dumb-init emulates the Linux "init" process, preventing the failure
-# of tests involving worker processes.
-ENTRYPOINT ["/usr/bin/dumb-init", "/root/go/bin/go", "run", "proctor-nodejs.go"]
diff --git a/test/runtimes/nodejs/proctor-nodejs.go b/test/runtimes/nodejs/proctor-nodejs.go
deleted file mode 100644
index 0624f6a0d..000000000
--- a/test/runtimes/nodejs/proctor-nodejs.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Binary proctor-nodejs is a utility that facilitates language testing for NodeJS.
-package main
-
-import (
- "fmt"
- "log"
- "os"
- "os/exec"
- "path/filepath"
- "regexp"
-
- "gvisor.dev/gvisor/test/runtimes/common"
-)
-
-var (
- dir = os.Getenv("LANG_DIR")
- testDir = filepath.Join(dir, "test")
- testRegEx = regexp.MustCompile(`^test-[^-].+\.js$`)
-)
-
-type nodejsRunner struct {
-}
-
-func main() {
- if err := common.LaunchFunc(nodejsRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
-
-func (n nodejsRunner) ListTests() ([]string, error) {
- testSlice, err := common.Search(testDir, testRegEx)
- if err != nil {
- return nil, err
- }
- return testSlice, nil
-}
-
-func (n nodejsRunner) RunTest(test string) error {
- args := []string{filepath.Join(dir, "tools", "test.py"), test}
- cmd := exec.Command("/usr/bin/python", args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
-}
diff --git a/test/runtimes/php/BUILD b/test/runtimes/php/BUILD
deleted file mode 100644
index 31799b77a..000000000
--- a/test/runtimes/php/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-php",
- srcs = ["proctor-php.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/php/Dockerfile b/test/runtimes/php/Dockerfile
deleted file mode 100644
index 491ab902d..000000000
--- a/test/runtimes/php/Dockerfile
+++ /dev/null
@@ -1,30 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=7.3.6
-ENV LANG_NAME=PHP
-
-RUN apt-get update && apt-get install -y \
- autoconf \
- automake \
- bison \
- build-essential \
- curl \
- libtool \
- libxml2-dev \
- re2c
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-RUN curl -o php-${LANG_VER}.tar.gz https://www.php.net/distributions/php-${LANG_VER}.tar.gz
-RUN tar -zxf php-${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/php-${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure
-RUN make
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY php/proctor-php.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-php.go"]
diff --git a/test/runtimes/python/BUILD b/test/runtimes/python/BUILD
deleted file mode 100644
index 37fd6a0f2..000000000
--- a/test/runtimes/python/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-python",
- srcs = ["proctor-python.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/python/Dockerfile b/test/runtimes/python/Dockerfile
deleted file mode 100644
index 710daee43..000000000
--- a/test/runtimes/python/Dockerfile
+++ /dev/null
@@ -1,32 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=3.7.3
-ENV LANG_NAME=Python
-
-RUN apt-get update && apt-get install -y \
- curl \
- gcc \
- libbz2-dev \
- libffi-dev \
- liblzma-dev \
- libreadline-dev \
- libssl-dev \
- make \
- zlib1g-dev
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-# Use flags -LJO to follow the html redirect and download .tar.gz.
-RUN curl -LJO https://github.com/python/cpython/archive/v${LANG_VER}.tar.gz
-RUN tar -zxf cpython-${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/cpython-${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure --with-pydebug
-RUN make -s -j2
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY python/proctor-python.go ${LANG_DIR}
-
-ENTRYPOINT ["/root/go/bin/go", "run", "proctor-python.go"]
diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go
new file mode 100644
index 000000000..bec37c69d
--- /dev/null
+++ b/test/runtimes/runner.go
@@ -0,0 +1,199 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary runner runs the runtime tests in a Docker container.
+package main
+
+import (
+ "encoding/csv"
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+var (
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ blacklistFile = flag.String("blacklist_file", "", "file containing blacklist of tests to exclude, in CSV format with fields: test name, bug id, comment")
+)
+
+// Wait time for each test to run.
+const timeout = 5 * time.Minute
+
+func main() {
+ flag.Parse()
+ if *lang == "" || *image == "" {
+ fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
+ os.Exit(1)
+ }
+
+ os.Exit(runTests())
+}
+
+// runTests is a helper that is called by main. It exists so that we can run
+// defered functions before exiting. It returns an exit code that should be
+// passed to os.Exit.
+func runTests() int {
+ // Get tests to blacklist.
+ blacklist, err := getBlacklist()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting blacklist: %s\n", err.Error())
+ return 1
+ }
+
+ // Create a single docker container that will be used for all tests.
+ d := dockerutil.MakeDocker("gvisor-" + *lang)
+ defer d.CleanUp()
+
+ // Get a slice of tests to run. This will also start a single Docker
+ // container that will be used to run each test. The final test will
+ // stop the Docker container.
+ tests, err := getTests(d, blacklist)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%s\n", err.Error())
+ return 1
+ }
+
+ m := testing.MainStart(testDeps{}, tests, nil, nil)
+ return m.Run()
+}
+
+// getTests returns a slice of tests to run, subject to the shard size and
+// index.
+func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) {
+ // Pull the image.
+ if err := dockerutil.Pull(*image); err != nil {
+ return nil, fmt.Errorf("docker pull %q failed: %v", *image, err)
+ }
+
+ // Run proctor with --pause flag to keep container alive forever.
+ if err := d.Run(*image, "--pause"); err != nil {
+ return nil, fmt.Errorf("docker run failed: %v", err)
+ }
+
+ // Get a list of all tests in the image.
+ list, err := d.Exec("/proctor", "--runtime", *lang, "--list")
+ if err != nil {
+ return nil, fmt.Errorf("docker exec failed: %v", err)
+ }
+
+ // Calculate a subset of tests to run corresponding to the current
+ // shard.
+ tests := strings.Fields(list)
+ sort.Strings(tests)
+ begin, end, err := testutil.TestBoundsForShard(len(tests))
+ if err != nil {
+ return nil, fmt.Errorf("TestsForShard() failed: %v", err)
+ }
+ log.Printf("Got bounds [%d:%d) for shard out of %d total tests", begin, end, len(tests))
+ tests = tests[begin:end]
+
+ var itests []testing.InternalTest
+ for _, tc := range tests {
+ // Capture tc in this scope.
+ tc := tc
+ itests = append(itests, testing.InternalTest{
+ Name: tc,
+ F: func(t *testing.T) {
+ // Is the test blacklisted?
+ if _, ok := blacklist[tc]; ok {
+ t.Skip("SKIP: blacklisted test %q", tc)
+ }
+
+ var (
+ now = time.Now()
+ done = make(chan struct{})
+ output string
+ err error
+ )
+
+ go func() {
+ fmt.Printf("RUNNING %s...\n", tc)
+ output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ if err == nil {
+ fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now))
+ return
+ }
+ t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output)
+ case <-time.After(timeout):
+ t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output)
+ }
+ },
+ })
+ }
+ return itests, nil
+}
+
+// getBlacklist reads the blacklist file and returns a set of test names to
+// exclude.
+func getBlacklist() (map[string]struct{}, error) {
+ blacklist := make(map[string]struct{})
+ if *blacklistFile == "" {
+ return blacklist, nil
+ }
+ file, err := testutil.FindFile(*blacklistFile)
+ if err != nil {
+ return nil, err
+ }
+ f, err := os.Open(file)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ r := csv.NewReader(f)
+
+ // First line is header. Skip it.
+ if _, err := r.Read(); err != nil {
+ return nil, err
+ }
+
+ for {
+ record, err := r.Read()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, err
+ }
+ blacklist[record[0]] = struct{}{}
+ }
+ return blacklist, nil
+}
+
+// testDeps implements testing.testDeps (an unexported interface), and is
+// required to use testing.MainStart.
+type testDeps struct{}
+
+func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil }
+func (f testDeps) StartCPUProfile(io.Writer) error { return nil }
+func (f testDeps) StopCPUProfile() {}
+func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
+func (f testDeps) ImportPath() string { return "" }
+func (f testDeps) StartTestLog(io.Writer) {}
+func (f testDeps) StopTestLog() error { return nil }
diff --git a/test/runtimes/runner.sh b/test/runtimes/runner.sh
new file mode 100755
index 000000000..a8d9a3460
--- /dev/null
+++ b/test/runtimes/runner.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euf -x -o pipefail
+
+echo -- "$@"
+
+# Create outputs dir if it does not exist.
+if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then
+ mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}"
+ chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}"
+fi
+
+# Update the timestamp on the shard status file. Bazel looks for this.
+touch "${TEST_SHARD_STATUS_FILE}"
+
+# Get location of runner binary.
+readonly runner=$(find "${TEST_SRCDIR}" -name runner)
+
+# Pass the arguments of this script directly to the runner.
+exec "${runner}" "$@"
+
diff --git a/test/runtimes/runtimes_test.go b/test/runtimes/runtimes_test.go
deleted file mode 100644
index 9421021a1..000000000
--- a/test/runtimes/runtimes_test.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package runtimes
-
-import (
- "strings"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/runsc/test/testutil"
-)
-
-// Wait time for each test to run.
-const timeout = 180 * time.Second
-
-// Helper function to execute the docker container associated with the
-// language passed. Captures the output of the list function and executes
-// each test individually, supplying any errors recieved.
-func testLang(t *testing.T, lang string) {
- t.Helper()
-
- img := "gcr.io/gvisor-presubmit/" + lang
- if err := testutil.Pull(img); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
-
- c := testutil.MakeDocker("gvisor-list")
-
- list, err := c.RunFg(img, "--list")
- if err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- c.CleanUp()
-
- tests := strings.Fields(list)
-
- for _, tc := range tests {
- tc := tc
- t.Run(tc, func(t *testing.T) {
- d := testutil.MakeDocker("gvisor-test")
- if err := d.Run(img, "--test", tc); err != nil {
- t.Fatalf("docker test %q failed to run: %v", tc, err)
- }
- defer d.CleanUp()
-
- status, err := d.Wait(timeout)
- if err != nil {
- t.Fatalf("docker test %q failed to wait: %v", tc, err)
- }
- if status == 0 {
- t.Logf("test %q passed", tc)
- return
- }
- logs, err := d.Logs()
- if err != nil {
- t.Fatalf("docker test %q failed to supply logs: %v", tc, err)
- }
- t.Errorf("test %q failed: %v", tc, logs)
- })
- }
-}
-
-func TestGo(t *testing.T) {
- testLang(t, "go")
-}
-
-func TestJava(t *testing.T) {
- testLang(t, "java")
-}
-
-func TestNodejs(t *testing.T) {
- testLang(t, "nodejs")
-}
-
-func TestPhp(t *testing.T) {
- testLang(t, "php")
-}
-
-func TestPython(t *testing.T) {
- testLang(t, "python")
-}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index ccae4925f..341e6b252 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -305,6 +305,8 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:proc_pid_uid_gid_map_test")
+syscall_test(test = "//test/syscalls/linux:proc_net_test")
+
syscall_test(
size = "medium",
test = "//test/syscalls/linux:pselect_test",
@@ -344,6 +346,11 @@ syscall_test(
)
syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:readahead_test",
+)
+
+syscall_test(
size = "medium",
shard_count = 5,
test = "//test/syscalls/linux:readv_socket_test",
@@ -691,8 +698,13 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:proc_net_unix_test")
+syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test")
+
+syscall_test(test = "//test/syscalls/linux:proc_net_udp_test")
+
go_binary(
name = "syscall_test_runner",
+ testonly = 1,
srcs = ["syscall_test_runner.go"],
data = [
"//runsc",
@@ -700,7 +712,7 @@ go_binary(
deps = [
"//pkg/log",
"//runsc/specutils",
- "//runsc/test/testutil",
+ "//runsc/testutil",
"//test/syscalls/gtest",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/test/syscalls/build_defs.bzl b/test/syscalls/build_defs.bzl
index 771d171f8..e94ef5602 100644
--- a/test/syscalls/build_defs.bzl
+++ b/test/syscalls/build_defs.bzl
@@ -94,6 +94,7 @@ def _syscall_test(
# more stable.
if platform == "kvm":
tags += ["manual"]
+ tags += ["requires-kvm"]
args = [
# Arguments are passed directly to syscall_test_runner binary.
@@ -122,3 +123,6 @@ def sh_test(**kwargs):
native.sh_test(
**kwargs
)
+
+def select_for_linux(for_linux, for_others = []):
+ return for_linux
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index ca4344139..d5a2b7725 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1,3 +1,6 @@
+load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
+load("//test/syscalls:build_defs.bzl", "select_for_linux")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
@@ -108,20 +111,27 @@ cc_library(
cc_library(
name = "socket_test_util",
testonly = 1,
- srcs = ["socket_test_util.cc"],
+ srcs = [
+ "socket_test_util.cc",
+ ] + select_for_linux(
+ [
+ "socket_test_util_impl.cc",
+ ],
+ ),
hdrs = ["socket_test_util.h"],
deps = [
+ "@com_google_googletest//:gtest",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
"//test/util:file_descriptor",
"//test/util:posix_error",
"//test/util:temp_path",
"//test/util:test_util",
"//test/util:thread_util",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/strings",
- "@com_google_absl//absl/strings:str_format",
- "@com_google_absl//absl/time",
- "@com_google_googletest//:gtest",
- ],
+ ] + select_for_linux([
+ ]),
)
cc_library(
@@ -256,12 +266,15 @@ cc_binary(
],
linkstatic = 1,
deps = [
- # The heap check doesn't handle mremap properly.
+ # The heapchecker doesn't recognize that io_destroy munmaps.
"@com_google_googletest//:gtest",
"@com_google_absl//absl/strings",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:memory_util",
"//test/util:posix_error",
+ "//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -381,6 +394,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest",
],
@@ -399,6 +413,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -715,6 +730,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:timer_util",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
@@ -963,6 +979,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
@@ -983,6 +1000,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -1204,6 +1222,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -1411,6 +1430,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -1427,6 +1447,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -1530,6 +1551,7 @@ cc_binary(
srcs = ["proc_net.cc"],
linkstatic = 1,
deps = [
+ "//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:fs_util",
"//test/util:test_main",
@@ -1609,6 +1631,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -1713,6 +1736,20 @@ cc_binary(
)
cc_binary(
+ name = "readahead_test",
+ testonly = 1,
+ srcs = ["readahead.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "readv_test",
testonly = 1,
srcs = [
@@ -1871,7 +1908,9 @@ cc_binary(
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ "//test/util:thread_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
@@ -1905,6 +1944,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
@@ -1957,6 +1997,24 @@ cc_binary(
)
cc_binary(
+ name = "signalfd_test",
+ testonly = 1,
+ srcs = ["signalfd.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "sigprocmask_test",
testonly = 1,
srcs = ["sigprocmask.cc"],
@@ -1979,6 +2037,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -2405,6 +2464,63 @@ cc_binary(
)
cc_binary(
+ name = "socket_bind_to_device_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_sequence_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_sequence.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_distribution_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_distribution.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "socket_ip_udp_loopback_non_blocking_test",
testonly = 1,
srcs = [
@@ -2681,6 +2797,23 @@ cc_library(
alwayslink = 1,
)
+cc_library(
+ name = "socket_bind_to_device_util",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_util.cc",
+ ],
+ hdrs = [
+ "socket_bind_to_device_util.h",
+ ],
+ deps = [
+ "//test/util:test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
+
cc_binary(
name = "socket_stream_local_test",
testonly = 1,
@@ -3104,8 +3237,6 @@ cc_binary(
testonly = 1,
srcs = ["timers.cc"],
linkstatic = 1,
- # FIXME(b/136599201)
- tags = ["flaky"],
deps = [
"//test/util:cleanup",
"//test/util:logging",
@@ -3114,6 +3245,7 @@ cc_binary(
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -3193,6 +3325,8 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:uid_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -3284,6 +3418,7 @@ cc_binary(
"//test/util:multiprocess_util",
"//test/util:test_util",
"//test/util:time_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -3353,6 +3488,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
@@ -3463,3 +3599,18 @@ cc_binary(
"@com_google_googletest//:gtest",
],
)
+
+cc_binary(
+ name = "proc_net_udp_test",
+ testonly = 1,
+ srcs = ["proc_net_udp.cc"],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/test/syscalls/linux/affinity.cc b/test/syscalls/linux/affinity.cc
index f2d8375b6..128364c34 100644
--- a/test/syscalls/linux/affinity.cc
+++ b/test/syscalls/linux/affinity.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include <sched.h>
+#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc
index 68dc05417..b27d4e10a 100644
--- a/test/syscalls/linux/aio.cc
+++ b/test/syscalls/linux/aio.cc
@@ -14,31 +14,57 @@
#include <fcntl.h>
#include <linux/aio_abi.h>
-#include <string.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
+#include <algorithm>
+#include <string>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/file_base.h"
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+using ::testing::_;
+
namespace gvisor {
namespace testing {
namespace {
+// Returns the size of the VMA containing the given address.
+PosixErrorOr<size_t> VmaSizeAt(uintptr_t addr) {
+ ASSIGN_OR_RETURN_ERRNO(std::string proc_self_maps,
+ GetContents("/proc/self/maps"));
+ ASSIGN_OR_RETURN_ERRNO(auto entries, ParseProcMaps(proc_self_maps));
+ // Use binary search to find the first VMA that might contain addr.
+ ProcMapsEntry target = {};
+ target.end = addr;
+ auto it =
+ std::upper_bound(entries.begin(), entries.end(), target,
+ [](const ProcMapsEntry& x, const ProcMapsEntry& y) {
+ return x.end < y.end;
+ });
+ // Check that it actually contains addr.
+ if (it == entries.end() || addr < it->start) {
+ return PosixError(ENOENT, absl::StrCat("no VMA contains address ", addr));
+ }
+ return it->end - it->start;
+}
+
constexpr char kData[] = "hello world!";
int SubmitCtx(aio_context_t ctx, long nr, struct iocb** iocbpp) {
return syscall(__NR_io_submit, ctx, nr, iocbpp);
}
-} // namespace
-
class AIOTest : public FileTest {
public:
AIOTest() : ctx_(0) {}
@@ -124,10 +150,10 @@ TEST_F(AIOTest, BasicWrite) {
EXPECT_EQ(events[0].res, strlen(kData));
// Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
}
TEST_F(AIOTest, BadWrite) {
@@ -220,38 +246,25 @@ TEST_F(AIOTest, CloneVm) {
TEST_F(AIOTest, Mremap) {
// Setup a context that is 128 entries deep.
ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
struct iocb cb = CreateCallback();
struct iocb* cbs[1] = {&cb};
// Reserve address space for the mremap target so we have something safe to
// map over.
- //
- // N.B. We reserve 2 pages because we'll attempt to remap to 2 pages below.
- // That should fail with EFAULT, but will fail with EINVAL if this mmap
- // returns the page immediately below ctx_, as
- // [new_address, new_address+2*kPageSize) overlaps [ctx_, ctx_+kPageSize).
- void* new_address = mmap(nullptr, 2 * kPageSize, PROT_READ,
- MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<intptr_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, 2 * kPageSize), SyscallSucceeds());
- });
-
- // Test that remapping to a larger address fails.
- void* res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, 2 * kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(reinterpret_cast<intptr_t>(res), SyscallFailsWithErrno(EFAULT));
+ Mapping dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
// Remap context 'handle' to a different address.
- res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(
- reinterpret_cast<intptr_t>(res),
- SyscallSucceedsWithValue(reinterpret_cast<intptr_t>(new_address)));
- mmap_cleanup.Release();
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
aio_context_t old_ctx = ctx_;
- ctx_ = reinterpret_cast<aio_context_t>(new_address);
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ // io_destroy() will unmap dst now.
+ dst.release();
// Check that submitting the request with the old 'ctx_' fails.
ASSERT_THAT(SubmitCtx(old_ctx, 1, cbs), SyscallFailsWithErrno(EINVAL));
@@ -260,18 +273,12 @@ TEST_F(AIOTest, Mremap) {
ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
// Remap again.
- new_address =
- mmap(nullptr, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup2 = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds());
- });
- res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(reinterpret_cast<int64_t>(res),
- SyscallSucceedsWithValue(reinterpret_cast<int64_t>(new_address)));
- mmap_cleanup2.Release();
- ctx_ = reinterpret_cast<aio_context_t>(new_address);
+ dst = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ dst.release();
// Get the reply with yet another 'ctx_' and verify it.
struct io_event events[1];
@@ -281,51 +288,33 @@ TEST_F(AIOTest, Mremap) {
EXPECT_EQ(events[0].res, strlen(kData));
// Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
}
-// Tests that AIO context can be replaced with a different mapping at the same
-// address and continue working. Don't ask why, but Linux allows it.
-TEST_F(AIOTest, MremapOver) {
+// Tests that AIO context cannot be expanded with mremap.
+TEST_F(AIOTest, MremapExpansion) {
// Setup a context that is 128 entries deep.
ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
- struct iocb cb = CreateCallback();
- struct iocb* cbs[1] = {&cb};
-
- ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
-
- // Allocate a new VMA, copy 'ctx_' content over, and remap it on top
- // of 'ctx_'.
- void* new_address = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
- MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds());
- });
-
- memcpy(new_address, reinterpret_cast<void*>(ctx_), kPageSize);
- void* res =
- mremap(new_address, kPageSize, kPageSize, MREMAP_FIXED | MREMAP_MAYMOVE,
- reinterpret_cast<void*>(ctx_));
- ASSERT_THAT(reinterpret_cast<int64_t>(res), SyscallSucceedsWithValue(ctx_));
- mmap_cleanup.Release();
-
- // Everything continues to work just fine.
- struct io_event events[1];
- ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1));
- EXPECT_EQ(events[0].data, 0x123);
- EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb));
- EXPECT_EQ(events[0].res, strlen(kData));
-
- // Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ // Reserve address space for the mremap target so we have something safe to
+ // map over.
+ Mapping dst = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(ctx_size + kPageSize, PROT_NONE, MAP_PRIVATE));
+
+ // Test that remapping to a larger address range fails.
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ PosixErrorIs(EFAULT, _));
+
+ // mm/mremap.c:sys_mremap() => mremap_to() does do_munmap() of the destination
+ // before it hits the VM_DONTEXPAND check in vma_to_resize(), so we should no
+ // longer munmap it (another thread may have created a mapping there).
+ dst.release();
}
// Tests that AIO calls fail if context's address is inaccessible.
@@ -429,5 +418,7 @@ TEST_P(AIOVectorizedParamTest, BadIOVecs) {
INSTANTIATE_TEST_SUITE_P(BadIOVecs, AIOVectorizedParamTest,
::testing::Values(IOCB_CMD_PREADV, IOCB_CMD_PWRITEV));
+} // namespace
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/base_poll_test.h b/test/syscalls/linux/base_poll_test.h
index 088831f9f..0d4a6701e 100644
--- a/test/syscalls/linux/base_poll_test.h
+++ b/test/syscalls/linux/base_poll_test.h
@@ -56,7 +56,7 @@ class TimerThread {
private:
mutable absl::Mutex mu_;
- bool cancel_ GUARDED_BY(mu_) = false;
+ bool cancel_ ABSL_GUARDED_BY(mu_) = false;
// Must be last to ensure that the destructor for the thread is run before
// any other member of the object is destroyed.
diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc
index 2e82f0b3a..7a28b674d 100644
--- a/test/syscalls/linux/chown.cc
+++ b/test/syscalls/linux/chown.cc
@@ -16,10 +16,12 @@
#include <grp.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/synchronization/notification.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -29,9 +31,9 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid1, 65534, "first scratch UID");
-DEFINE_int32(scratch_uid2, 65533, "second scratch UID");
-DEFINE_int32(scratch_gid, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID");
namespace gvisor {
namespace testing {
@@ -100,10 +102,12 @@ TEST_P(ChownParamTest, ChownFilePermissionDenied) {
// Change EUID and EGID.
//
// See note about POSIX below.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid1, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid1), -1),
+ SyscallSucceeds());
EXPECT_THAT(GetParam()(file.path(), geteuid(), getegid()),
PosixErrorIs(EPERM, ::testing::ContainsRegex("chown")));
@@ -125,8 +129,9 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
// setresuid syscall. However, we want this thread to have its own set of
// credentials different from the parent process, so we use the raw
// syscall.
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid2, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid2), -1),
+ SyscallSucceeds());
// Create file and immediately close it.
FileDescriptor fd =
@@ -143,12 +148,13 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
fileCreated.WaitForNotification();
// Set file's owners to someone different.
- EXPECT_NO_ERRNO(GetParam()(filename, FLAGS_scratch_uid1, FLAGS_scratch_gid));
+ EXPECT_NO_ERRNO(GetParam()(filename, absl::GetFlag(FLAGS_scratch_uid1),
+ absl::GetFlag(FLAGS_scratch_gid)));
struct stat s;
EXPECT_THAT(stat(filename.c_str(), &s), SyscallSucceeds());
- EXPECT_EQ(s.st_uid, FLAGS_scratch_uid1);
- EXPECT_EQ(s.st_gid, FLAGS_scratch_gid);
+ EXPECT_EQ(s.st_uid, absl::GetFlag(FLAGS_scratch_uid1));
+ EXPECT_EQ(s.st_gid, absl::GetFlag(FLAGS_scratch_gid));
fileChowned.Notify();
}
diff --git a/test/syscalls/linux/clock_nanosleep.cc b/test/syscalls/linux/clock_nanosleep.cc
index 52a69d230..b55cddc52 100644
--- a/test/syscalls/linux/clock_nanosleep.cc
+++ b/test/syscalls/linux/clock_nanosleep.cc
@@ -43,7 +43,7 @@ int sys_clock_nanosleep(clockid_t clkid, int flags,
PosixErrorOr<absl::Time> GetTime(clockid_t clk) {
struct timespec ts = {};
- int rc = clock_gettime(clk, &ts);
+ const int rc = clock_gettime(clk, &ts);
MaybeSave();
if (rc < 0) {
return PosixError(errno, "clock_gettime");
@@ -67,31 +67,32 @@ TEST_P(WallClockNanosleepTest, InvalidValues) {
}
TEST_P(WallClockNanosleepTest, SleepOneSecond) {
- absl::Duration const duration = absl::Seconds(1);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(1);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ EXPECT_THAT(
+ RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_GE(after - before, duration);
+ EXPECT_GE(after - before, kSleepDuration);
}
TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
- absl::Duration const duration = absl::Seconds(60);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(60);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Install no-op signal handler for SIGALRM.
struct sigaction sa = {};
sigfillset(&sa.sa_mask);
sa.sa_handler = +[](int signo) {};
- auto const cleanup_sa =
+ const auto cleanup_sa =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
// Measure time since setting the alarm, since the alarm will interrupt the
// sleep and hence determine how long we sleep.
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
// Set an alarm to go off while sleeping.
struct itimerval timer = {};
@@ -99,26 +100,51 @@ TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
timer.it_value.tv_usec = 0;
timer.it_interval.tv_sec = 1;
timer.it_interval.tv_usec = 0;
- auto const cleanup =
+ const auto cleanup =
ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, timer));
- EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &dur, &dur),
+ EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &duration, &duration),
SyscallFailsWithErrno(EINTR));
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Duration const remaining = absl::DurationFromTimespec(dur);
- EXPECT_GE(after - before + remaining, duration);
+ // Remaining time updated.
+ const absl::Duration remaining = absl::DurationFromTimespec(duration);
+ EXPECT_GE(after - before + remaining, kSleepDuration);
+}
+
+// Remaining time is *not* updated if nanosleep completes uninterrupted.
+TEST_P(WallClockNanosleepTest, UninterruptedNanosleep) {
+ constexpr absl::Duration kSleepDuration = absl::Milliseconds(10);
+ const struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ while (true) {
+ constexpr int kRemainingMagic = 42;
+ struct timespec remaining;
+ remaining.tv_sec = kRemainingMagic;
+ remaining.tv_nsec = kRemainingMagic;
+
+ int ret = sys_clock_nanosleep(GetParam(), 0, &duration, &remaining);
+ if (ret == EINTR) {
+ // Retry from beginning. We want a single uninterrupted call.
+ continue;
+ }
+
+ EXPECT_THAT(ret, SyscallSucceeds());
+ EXPECT_EQ(remaining.tv_sec, kRemainingMagic);
+ EXPECT_EQ(remaining.tv_nsec, kRemainingMagic);
+ break;
+ }
}
TEST_P(WallClockNanosleepTest, SleepUntil) {
- absl::Time const now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Time const until = now + absl::Seconds(2);
- struct timespec ts = absl::ToTimespec(until);
+ const absl::Time now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time until = now + absl::Seconds(2);
+ const struct timespec ts = absl::ToTimespec(until);
EXPECT_THAT(
RetryEINTR(sys_clock_nanosleep)(GetParam(), TIMER_ABSTIME, &ts, nullptr),
SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
EXPECT_GE(after, until);
}
@@ -127,8 +153,8 @@ INSTANTIATE_TEST_SUITE_P(Sleepers, WallClockNanosleepTest,
::testing::Values(CLOCK_REALTIME, CLOCK_MONOTONIC));
TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
- absl::Duration const kDuration = absl::Seconds(5);
- struct timespec dur = absl::ToTimespec(kDuration);
+ const absl::Duration kSleepDuration = absl::Seconds(5);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Ensure that CLOCK_PROCESS_CPUTIME_ID advances.
std::atomic<bool> done(false);
@@ -136,16 +162,16 @@ TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
while (!done.load()) {
}
});
- auto const cleanup_done = Cleanup([&] { done.store(true); });
+ const auto cleanup_done = Cleanup([&] { done.store(true); });
- absl::Time const before =
+ const absl::Time before =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_THAT(
- RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after =
+ EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0,
+ &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_GE(after - before, kDuration);
+ EXPECT_GE(after - before, kSleepDuration);
}
} // namespace
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 91b55015c..0a3931e5a 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -401,12 +401,17 @@ TEST(ElfTest, DataSegment) {
})));
}
-// Additonal pages beyond filesz are always RW.
+// Additonal pages beyond filesz honor (only) execute protections.
//
-// N.B. Linux uses set_brk -> vm_brk to additional pages beyond filesz (even
-// though start_brk itself will always be beyond memsz). As a result, the
-// segment permissions don't apply; the mapping is always RW.
+// N.B. Linux changed this in 4.11 (16e72e9b30986 "powerpc: do not make the
+// entire heap executable"). Previously, extra pages were always RW.
TEST(ElfTest, ExtraMemPages) {
+ // gVisor has the newer behavior.
+ if (!IsRunningOnGvisor()) {
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 11));
+ }
+
ElfBinary<64> elf = StandardElf();
// Create a standard ELF, but extend to 1.5 pages. The second page will be the
@@ -415,7 +420,7 @@ TEST(ElfTest, ExtraMemPages) {
decltype(elf)::ElfPhdr phdr = {};
phdr.p_type = PT_LOAD;
- // RWX segment. The extra anon page will be RW anyways.
+ // RWX segment. The extra anon page will also be RWX.
//
// N.B. Linux uses clear_user to clear the end of the file-mapped page, which
// respects the mapping protections. Thus if we map this RO with memsz >
@@ -454,7 +459,7 @@ TEST(ElfTest, ExtraMemPages) {
{0x41000, 0x42000, true, true, true, true, kPageSize, 0, 0, 0,
file.path().c_str()},
// extra page from anon.
- {0x42000, 0x43000, true, true, false, true, 0, 0, 0, 0, ""},
+ {0x42000, 0x43000, true, true, true, true, 0, 0, 0, 0, ""},
})));
}
@@ -469,7 +474,7 @@ TEST(ElfTest, AnonOnlySegment) {
phdr.p_offset = 0;
phdr.p_vaddr = 0x41000;
phdr.p_filesz = 0;
- phdr.p_memsz = kPageSize - 0xe8;
+ phdr.p_memsz = kPageSize;
elf.phdrs.push_back(phdr);
elf.UpdateOffsets();
@@ -854,6 +859,11 @@ TEST(ElfTest, ELFInterpreter) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // N.B. Since Linux 4.10 (0036d1f7eb95b "binfmt_elf: fix calculations for bss
+ // padding"), Linux unconditionally zeroes the remainder of the highest mapped
+ // page in an interpreter, failing if the protections don't allow write. Thus
+ // we must mark this writeable.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -903,15 +913,15 @@ TEST(ElfTest, ELFInterpreter) {
const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
- EXPECT_THAT(child,
- ContainsMappings(std::vector<ProcMapsEntry>({
- // Main binary
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
- binary_file.path().c_str()},
- // Interpreter
- {interp_load_addr, interp_load_addr + 0x1000, true, false,
- true, true, 0, 0, 0, 0, interpreter_file.path().c_str()},
- })));
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
}
// Test parameter to ElfInterpterStaticTest cases. The first item is a suffix to
@@ -928,6 +938,8 @@ TEST_P(ElfInterpreterStaticTest, Test) {
const int expected_errno = std::get<1>(GetParam());
ElfBinary<64> interpreter = StandardElf();
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.UpdateOffsets();
TempPath interpreter_file =
ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
@@ -957,7 +969,7 @@ TEST_P(ElfInterpreterStaticTest, Test) {
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
// Interpreter.
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0,
+ {0x40000, 0x41000, true, true, true, true, 0, 0, 0,
0, interpreter_file.path().c_str()},
})));
}
@@ -1035,6 +1047,8 @@ TEST(ElfTest, ELFInterpreterRelative) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -1073,15 +1087,15 @@ TEST(ElfTest, ELFInterpreterRelative) {
const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
- EXPECT_THAT(child,
- ContainsMappings(std::vector<ProcMapsEntry>({
- // Main binary
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
- binary_file.path().c_str()},
- // Interpreter
- {interp_load_addr, interp_load_addr + 0x1000, true, false,
- true, true, 0, 0, 0, 0, interpreter_file.path().c_str()},
- })));
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
}
// ELF interpreter architecture doesn't match the binary.
@@ -1095,6 +1109,8 @@ TEST(ElfTest, ELFInterpreterWrongArch) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -1174,6 +1190,8 @@ TEST(ElfTest, ElfInterpreterNoExecute) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 2f8e7c9dd..8a45be12a 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -17,9 +17,12 @@
#include <syscall.h>
#include <unistd.h>
+#include <string>
+
#include "gtest/gtest.h"
#include "absl/base/macros.h"
#include "absl/base/port.h"
+#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
@@ -33,18 +36,19 @@
#include "test/util/test_util.h"
#include "test/util/timer_util.h"
-DEFINE_string(child_setlock_on, "",
- "Contains the path to try to set a file lock on.");
-DEFINE_bool(child_setlock_write, false,
- "Whether to set a writable lock (otherwise readable)");
-DEFINE_bool(blocking, false,
- "Whether to set a blocking lock (otherwise non-blocking).");
-DEFINE_bool(retry_eintr, false, "Whether to retry in the subprocess on EINTR.");
-DEFINE_uint64(child_setlock_start, 0, "The value of struct flock start");
-DEFINE_uint64(child_setlock_len, 0, "The value of struct flock len");
-DEFINE_int32(socket_fd, -1,
- "A socket to use for communicating more state back "
- "to the parent.");
+ABSL_FLAG(std::string, child_setlock_on, "",
+ "Contains the path to try to set a file lock on.");
+ABSL_FLAG(bool, child_setlock_write, false,
+ "Whether to set a writable lock (otherwise readable)");
+ABSL_FLAG(bool, blocking, false,
+ "Whether to set a blocking lock (otherwise non-blocking).");
+ABSL_FLAG(bool, retry_eintr, false,
+ "Whether to retry in the subprocess on EINTR.");
+ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start");
+ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len");
+ABSL_FLAG(int32_t, socket_fd, -1,
+ "A socket to use for communicating more state back "
+ "to the parent.");
namespace gvisor {
namespace testing {
@@ -918,25 +922,26 @@ TEST(FcntlTest, GetOwn) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (!FLAGS_child_setlock_on.empty()) {
- int socket_fd = FLAGS_socket_fd;
- int fd = open(FLAGS_child_setlock_on.c_str(), O_RDWR, 0666);
+ const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on);
+ if (!setlock_on.empty()) {
+ int socket_fd = absl::GetFlag(FLAGS_socket_fd);
+ int fd = open(setlock_on.c_str(), O_RDWR, 0666);
if (fd == -1 && errno != 0) {
int err = errno;
- std::cerr << "CHILD open " << FLAGS_child_setlock_on << " failed " << err
+ std::cerr << "CHILD open " << setlock_on << " failed " << err
<< std::endl;
exit(err);
}
struct flock fl;
- if (FLAGS_child_setlock_write) {
+ if (absl::GetFlag(FLAGS_child_setlock_write)) {
fl.l_type = F_WRLCK;
} else {
fl.l_type = F_RDLCK;
}
fl.l_whence = SEEK_SET;
- fl.l_start = FLAGS_child_setlock_start;
- fl.l_len = FLAGS_child_setlock_len;
+ fl.l_start = absl::GetFlag(FLAGS_child_setlock_start);
+ fl.l_len = absl::GetFlag(FLAGS_child_setlock_len);
// Test the fcntl, no need to log, the error is unambiguously
// from fcntl at this point.
@@ -946,8 +951,8 @@ int main(int argc, char** argv) {
gvisor::testing::MonotonicTimer timer;
timer.Start();
do {
- ret = fcntl(fd, FLAGS_blocking ? F_SETLKW : F_SETLK, &fl);
- } while (FLAGS_retry_eintr && ret == -1 && errno == EINTR);
+ ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
+ } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
auto usec = absl::ToInt64Microseconds(timer.Duration());
if (ret == -1 && errno != 0) {
diff --git a/test/syscalls/linux/futex.cc b/test/syscalls/linux/futex.cc
index d2cbbdb49..d3e3f998c 100644
--- a/test/syscalls/linux/futex.cc
+++ b/test/syscalls/linux/futex.cc
@@ -125,6 +125,10 @@ int futex_lock_pi(bool priv, std::atomic<int>* uaddr) {
if (priv) {
op |= FUTEX_PRIVATE_FLAG;
}
+ int zero = 0;
+ if (uaddr->compare_exchange_strong(zero, gettid())) {
+ return 0;
+ }
return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
}
@@ -133,6 +137,10 @@ int futex_trylock_pi(bool priv, std::atomic<int>* uaddr) {
if (priv) {
op |= FUTEX_PRIVATE_FLAG;
}
+ int zero = 0;
+ if (uaddr->compare_exchange_strong(zero, gettid())) {
+ return 0;
+ }
return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
}
@@ -141,6 +149,10 @@ int futex_unlock_pi(bool priv, std::atomic<int>* uaddr) {
if (priv) {
op |= FUTEX_PRIVATE_FLAG;
}
+ int tid = gettid();
+ if (uaddr->compare_exchange_strong(tid, 0)) {
+ return 0;
+ }
return RetryEINTR(syscall)(SYS_futex, uaddr, op, nullptr, nullptr);
}
@@ -692,8 +704,8 @@ TEST_P(PrivateAndSharedFutexTest, PITryLockConcurrency_NoRandomSave) {
std::unique_ptr<ScopedThread> threads[10];
for (size_t i = 0; i < ABSL_ARRAYSIZE(threads); ++i) {
threads[i] = absl::make_unique<ScopedThread>([is_priv, &a] {
- for (size_t j = 0; j < 100;) {
- if (futex_trylock_pi(is_priv, &a) >= 0) {
+ for (size_t j = 0; j < 10;) {
+ if (futex_trylock_pi(is_priv, &a) == 0) {
++j;
EXPECT_EQ(a.load() & FUTEX_TID_MASK, gettid());
SleepSafe(absl::Milliseconds(5));
diff --git a/test/syscalls/linux/ip_socket_test_util.cc b/test/syscalls/linux/ip_socket_test_util.cc
index c73262e72..410b42a47 100644
--- a/test/syscalls/linux/ip_socket_test_util.cc
+++ b/test/syscalls/linux/ip_socket_test_util.cc
@@ -23,6 +23,16 @@
namespace gvisor {
namespace testing {
+uint32_t IPFromInetSockaddr(const struct sockaddr* addr) {
+ auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
+ return in_addr->sin_addr.s_addr;
+}
+
+uint16_t PortFromInetSockaddr(const struct sockaddr* addr) {
+ auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
+ return ntohs(in_addr->sin_port);
+}
+
PosixErrorOr<int> InterfaceIndex(std::string name) {
// TODO(igudger): Consider using netlink.
ifreq req = {};
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index b498a053d..3d36b9620 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -26,6 +26,31 @@
namespace gvisor {
namespace testing {
+// Possible values of the "st" field in a /proc/net/{tcp,udp} entry. Source:
+// Linux kernel, include/net/tcp_states.h.
+enum {
+ TCP_ESTABLISHED = 1,
+ TCP_SYN_SENT,
+ TCP_SYN_RECV,
+ TCP_FIN_WAIT1,
+ TCP_FIN_WAIT2,
+ TCP_TIME_WAIT,
+ TCP_CLOSE,
+ TCP_CLOSE_WAIT,
+ TCP_LAST_ACK,
+ TCP_LISTEN,
+ TCP_CLOSING,
+ TCP_NEW_SYN_RECV,
+
+ TCP_MAX_STATES
+};
+
+// Extracts the IP address from an inet sockaddr in network byte order.
+uint32_t IPFromInetSockaddr(const struct sockaddr* addr);
+
+// Extracts the port from an inet sockaddr in host byte order.
+uint16_t PortFromInetSockaddr(const struct sockaddr* addr);
+
// InterfaceIndex returns the index of the named interface.
PosixErrorOr<int> InterfaceIndex(std::string name);
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index 51ce323b9..930d2b940 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -336,7 +336,9 @@ int main(int argc, char** argv) {
}
if (arg == gvisor::testing::kSIGPROFFairnessIdle) {
MaskSIGPIPE();
- return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(10));
+ // Sleep time > ClockTick (10ms) exercises sleeping gVisor's
+ // kernel.cpuClockTicker.
+ return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(25));
}
}
diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc
index 18ad923b8..db29bd59c 100644
--- a/test/syscalls/linux/kill.cc
+++ b/test/syscalls/linux/kill.cc
@@ -21,6 +21,7 @@
#include <csignal>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
@@ -31,8 +32,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
-DEFINE_int32(scratch_gid, 65534, "scratch GID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "scratch GID");
using ::testing::Ge;
@@ -255,8 +256,8 @@ TEST(KillTest, ProcessGroups) {
TEST(KillTest, ChildDropsPrivsCannotKill) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
- int uid = FLAGS_scratch_uid;
- int gid = FLAGS_scratch_gid;
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
// Create the child that drops privileges and tries to kill the parent.
pid_t pid = fork();
@@ -331,8 +332,8 @@ TEST(KillTest, CanSIGCONTSameSession) {
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
<< "status " << status;
- int uid = FLAGS_scratch_uid;
- int gid = FLAGS_scratch_gid;
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
// Drop privileges only in child process, or else this parent process won't be
// able to open some log files after the test ends.
diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc
index a91703070..dd5352954 100644
--- a/test/syscalls/linux/link.cc
+++ b/test/syscalls/linux/link.cc
@@ -22,6 +22,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -31,7 +32,7 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
namespace gvisor {
namespace testing {
@@ -92,7 +93,8 @@ TEST(LinkTest, PermissionDenied) {
// threads have the same UIDs, so using the setuid wrapper sets all threads'
// real UID.
// Also drops capabilities.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()),
SyscallFailsWithErrno(EPERM));
diff --git a/test/syscalls/linux/mlock.cc b/test/syscalls/linux/mlock.cc
index aee4f7d1a..283c21ed3 100644
--- a/test/syscalls/linux/mlock.cc
+++ b/test/syscalls/linux/mlock.cc
@@ -169,26 +169,24 @@ TEST(MlockallTest, Future) {
// Run this test in a separate (single-threaded) subprocess to ensure that a
// background thread doesn't try to mmap a large amount of memory, fail due
// to hitting RLIMIT_MEMLOCK, and explode the process violently.
- EXPECT_THAT(InForkedProcess([] {
- auto const mapping =
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE)
- .ValueOrDie();
- TEST_CHECK(!IsPageMlocked(mapping.addr()));
- TEST_PCHECK(mlockall(MCL_FUTURE) == 0);
- // Ensure that mlockall(MCL_FUTURE) is turned off before the end
- // of the test, as otherwise mmaps may fail unexpectedly.
- Cleanup do_munlockall([] { TEST_PCHECK(munlockall() == 0); });
- auto const mapping2 = ASSERT_NO_ERRNO_AND_VALUE(
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- TEST_CHECK(IsPageMlocked(mapping2.addr()));
- // Fire munlockall() and check that it disables
- // mlockall(MCL_FUTURE).
- do_munlockall.Release()();
- auto const mapping3 = ASSERT_NO_ERRNO_AND_VALUE(
- MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
- TEST_CHECK(!IsPageMlocked(mapping2.addr()));
- }),
- IsPosixErrorOkAndHolds(0));
+ auto const do_test = [] {
+ auto const mapping =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(!IsPageMlocked(mapping.addr()));
+ TEST_PCHECK(mlockall(MCL_FUTURE) == 0);
+ // Ensure that mlockall(MCL_FUTURE) is turned off before the end of the
+ // test, as otherwise mmaps may fail unexpectedly.
+ Cleanup do_munlockall([] { TEST_PCHECK(munlockall() == 0); });
+ auto const mapping2 =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(IsPageMlocked(mapping2.addr()));
+ // Fire munlockall() and check that it disables mlockall(MCL_FUTURE).
+ do_munlockall.Release()();
+ auto const mapping3 =
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE).ValueOrDie();
+ TEST_CHECK(!IsPageMlocked(mapping2.addr()));
+ };
+ EXPECT_THAT(InForkedProcess(do_test), IsPosixErrorOkAndHolds(0));
}
TEST(MunlockallTest, Basic) {
diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc
index 64e435cb7..f0e5f7d82 100644
--- a/test/syscalls/linux/mremap.cc
+++ b/test/syscalls/linux/mremap.cc
@@ -35,17 +35,6 @@ namespace testing {
namespace {
-// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of
-// void* isn't directly compatible with SyscallSucceeds.
-PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, size_t new_size,
- int flags, void* new_address) {
- void* rv = mremap(old_address, old_size, new_size, flags, new_address);
- if (rv == MAP_FAILED) {
- return PosixError(errno, "mremap failed");
- }
- return rv;
-}
-
// Fixture for mremap tests parameterized by mmap flags.
using MremapParamTest = ::testing::TestWithParam<int>;
diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc
index e0525f386..2b1df52ce 100644
--- a/test/syscalls/linux/open.cc
+++ b/test/syscalls/linux/open.cc
@@ -21,6 +21,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
#include "test/syscalls/linux/file_base.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 7a3379b9e..37b4e6575 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -83,9 +83,15 @@ void SendUDPMessage(int sock) {
// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
TEST(BasicCookedPacketTest, WrongType) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
FileDescriptor sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP));
@@ -118,18 +124,27 @@ class CookedPacketTest : public ::testing::TestWithParam<int> {
};
void CookedPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
SyscallSucceeds());
}
void CookedPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int CookedPacketTest::GetLoopbackIndex() {
@@ -142,9 +157,6 @@ int CookedPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(CookedPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -201,9 +213,6 @@ TEST_P(CookedPacketTest, Receive) {
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 9e96460ee..6491453b6 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -97,9 +97,15 @@ class RawPacketTest : public ::testing::TestWithParam<int> {
};
void RawPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
if (!IsRunningOnGvisor()) {
FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
@@ -119,10 +125,13 @@ void RawPacketTest::SetUp() {
}
void RawPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int RawPacketTest::GetLoopbackIndex() {
@@ -135,9 +144,6 @@ int RawPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(RawPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -208,9 +214,6 @@ TEST_P(RawPacketTest, Receive) {
// Send via a packet socket.
TEST_P(RawPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index 65afb90f3..10e2a6dfc 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -168,6 +168,20 @@ TEST_P(PipeTest, Write) {
EXPECT_EQ(wbuf, rbuf);
}
+TEST_P(PipeTest, WritePage) {
+ SKIP_IF(!CreateBlocking());
+
+ std::vector<char> wbuf(kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ std::vector<char> rbuf(wbuf.size());
+
+ ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+ ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0);
+}
+
TEST_P(PipeTest, NonBlocking) {
SKIP_IF(!CreateNonBlocking());
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index bd1779557..d07571a5f 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -21,6 +21,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
#include "test/util/multiprocess_util.h"
@@ -28,9 +29,9 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(prctl_no_new_privs_test_child, false,
- "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) "
- "plus an offset (see test source).");
+ABSL_FLAG(bool, prctl_no_new_privs_test_child, false,
+ "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) "
+ "plus an offset (see test source).");
namespace gvisor {
namespace testing {
@@ -220,7 +221,7 @@ TEST(PrctlTest, RootDumpability) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_prctl_no_new_privs_test_child) {
+ if (absl::GetFlag(FLAGS_prctl_no_new_privs_test_child)) {
exit(gvisor::testing::kPrctlNoNewPrivsTestChildExitBase +
prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0));
}
diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc
index 00dd6523e..30f0d75b3 100644
--- a/test/syscalls/linux/prctl_setuid.cc
+++ b/test/syscalls/linux/prctl_setuid.cc
@@ -14,9 +14,11 @@
#include <sched.h>
#include <sys/prctl.h>
+
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/logging.h"
#include "test/util/multiprocess_util.h"
@@ -24,12 +26,12 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
// This flag is used to verify that after an exec PR_GET_KEEPCAPS
// returns 0, the return code will be offset by kPrGetKeepCapsExitBase.
-DEFINE_bool(prctl_pr_get_keepcaps, false,
- "If true the test will verify that prctl with pr_get_keepcaps"
- "returns 0. The test will exit with the result of that check.");
+ABSL_FLAG(bool, prctl_pr_get_keepcaps, false,
+ "If true the test will verify that prctl with pr_get_keepcaps"
+ "returns 0. The test will exit with the result of that check.");
// These tests exist seperately from prctl because we need to start
// them as root. Setuid() has the behavior that permissions are fully
@@ -113,10 +115,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidNoKeepCaps) {
// call to only apply to this task. POSIX threads, however, require that
// all threads have the same UIDs, so using the setuid wrapper sets all
// threads' real UID.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
// Verify that we changed uid.
- EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid));
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
// Verify we lost the capability in the effective set, this always happens.
TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
@@ -157,10 +161,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidKeepCaps) {
// call to only apply to this task. POSIX threads, however, require that
// all threads have the same UIDs, so using the setuid wrapper sets all
// threads' real UID.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
// Verify that we changed uid.
- EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid));
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
// Verify we lost the capability in the effective set, this always happens.
TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
@@ -253,7 +259,7 @@ TEST_F(PrctlKeepCapsSetuidTest, PrGetKeepCaps) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_prctl_pr_get_keepcaps) {
+ if (absl::GetFlag(FLAGS_prctl_pr_get_keepcaps)) {
return gvisor::testing::kPrGetKeepCapsExitBase +
prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0);
}
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index b440ba0df..e4c030bbb 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -440,6 +440,11 @@ TEST(ProcSelfAuxv, EntryPresence) {
EXPECT_EQ(auxv_entries.count(AT_PHENT), 1);
EXPECT_EQ(auxv_entries.count(AT_PHNUM), 1);
EXPECT_EQ(auxv_entries.count(AT_BASE), 1);
+ EXPECT_EQ(auxv_entries.count(AT_UID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EUID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_GID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EGID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_SECURE), 1);
EXPECT_EQ(auxv_entries.count(AT_CLKTCK), 1);
EXPECT_EQ(auxv_entries.count(AT_RANDOM), 1);
EXPECT_EQ(auxv_entries.count(AT_EXECFN), 1);
@@ -1602,9 +1607,9 @@ class BlockingChild {
}
mutable absl::Mutex mu_;
- bool stop_ GUARDED_BY(mu_) = false;
+ bool stop_ ABSL_GUARDED_BY(mu_) = false;
pid_t tid_;
- bool tid_ready_ GUARDED_BY(mu_) = false;
+ bool tid_ready_ ABSL_GUARDED_BY(mu_) = false;
// Must be last to ensure that the destructor for the thread is run before
// any other member of the object is destroyed.
@@ -1882,7 +1887,9 @@ void CheckDuplicatesRecursively(std::string path) {
errno = 0;
DIR* dir = opendir(path.c_str());
if (dir == nullptr) {
- ASSERT_THAT(errno, ::testing::AnyOf(EPERM, EACCES)) << path;
+ // Ignore any directories we can't read or missing directories as the
+ // directory could have been deleted/mutated from the time the parent
+ // directory contents were read.
return;
}
auto dir_closer = Cleanup([&dir]() { closedir(dir); });
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 03d0665eb..efdaf202b 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -14,6 +14,7 @@
#include "gtest/gtest.h"
#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
#include "test/util/test_util.h"
@@ -27,7 +28,7 @@ TEST(ProcNetIfInet6, Format) {
EXPECT_THAT(ifinet6,
::testing::MatchesRegex(
// Ex: "00000000000000000000000000000001 01 80 10 80 lo\n"
- "^([a-f\\d]{32}( [a-f\\d]{2}){4} +[a-z][a-z\\d]*\\n)+$"));
+ "^([a-f0-9]{32}( [a-f0-9]{2}){4} +[a-z][a-z0-9]*\n)+$"));
}
TEST(ProcSysNetIpv4Sack, Exists) {
@@ -35,6 +36,8 @@ TEST(ProcSysNetIpv4Sack, Exists) {
}
TEST(ProcSysNetIpv4Sack, CanReadAndWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE))));
+
auto const fd =
ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/tcp_sack", O_RDWR));
diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc
index 498f62d9c..f6d7ad0bb 100644
--- a/test/syscalls/linux/proc_net_tcp.cc
+++ b/test/syscalls/linux/proc_net_tcp.cc
@@ -38,25 +38,6 @@ constexpr char kProcNetTCPHeader[] =
"retrnsmt uid timeout inode "
" ";
-// Possible values of the "st" field in a /proc/net/tcp entry. Source: Linux
-// kernel, include/net/tcp_states.h.
-enum {
- TCP_ESTABLISHED = 1,
- TCP_SYN_SENT,
- TCP_SYN_RECV,
- TCP_FIN_WAIT1,
- TCP_FIN_WAIT2,
- TCP_TIME_WAIT,
- TCP_CLOSE,
- TCP_CLOSE_WAIT,
- TCP_LAST_ACK,
- TCP_LISTEN,
- TCP_CLOSING,
- TCP_NEW_SYN_RECV,
-
- TCP_MAX_STATES
-};
-
// TCPEntry represents a single entry from /proc/net/tcp.
struct TCPEntry {
uint32_t local_addr;
@@ -70,42 +51,35 @@ struct TCPEntry {
uint64_t inode;
};
-uint32_t IP(const struct sockaddr* addr) {
- auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
- return in_addr->sin_addr.s_addr;
-}
-
-uint16_t Port(const struct sockaddr* addr) {
- auto* in_addr = reinterpret_cast<const struct sockaddr_in*>(addr);
- return ntohs(in_addr->sin_port);
-}
-
// Finds the first entry in 'entries' for which 'predicate' returns true.
-// Returns true on match, and sets 'match' to point to the matching entry.
-bool FindBy(std::vector<TCPEntry> entries, TCPEntry* match,
+// Returns true on match, and sets 'match' to a copy of the matching entry. If
+// 'match' is null, it's ignored.
+bool FindBy(const std::vector<TCPEntry>& entries, TCPEntry* match,
std::function<bool(const TCPEntry&)> predicate) {
- for (int i = 0; i < entries.size(); ++i) {
- if (predicate(entries[i])) {
- *match = entries[i];
+ for (const TCPEntry& entry : entries) {
+ if (predicate(entry)) {
+ if (match != nullptr) {
+ *match = entry;
+ }
return true;
}
}
return false;
}
-bool FindByLocalAddr(std::vector<TCPEntry> entries, TCPEntry* match,
+bool FindByLocalAddr(const std::vector<TCPEntry>& entries, TCPEntry* match,
const struct sockaddr* addr) {
- uint32_t host = IP(addr);
- uint16_t port = Port(addr);
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
return FindBy(entries, match, [host, port](const TCPEntry& e) {
return (e.local_addr == host && e.local_port == port);
});
}
-bool FindByRemoteAddr(std::vector<TCPEntry> entries, TCPEntry* match,
+bool FindByRemoteAddr(const std::vector<TCPEntry>& entries, TCPEntry* match,
const struct sockaddr* addr) {
- uint32_t host = IP(addr);
- uint16_t port = Port(addr);
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
return FindBy(entries, match, [host, port](const TCPEntry& e) {
return (e.remote_addr == host && e.remote_port == port);
});
@@ -120,7 +94,7 @@ PosixErrorOr<std::vector<TCPEntry>> ProcNetTCPEntries() {
std::vector<TCPEntry> entries;
std::vector<std::string> lines = StrSplit(content, '\n');
std::cerr << "<contents of /proc/net/tcp>" << std::endl;
- for (std::string line : lines) {
+ for (const std::string& line : lines) {
std::cerr << line << std::endl;
if (!found_header) {
@@ -204,9 +178,8 @@ TEST(ProcNetTCP, BindAcceptConnect) {
EXPECT_EQ(entries.size(), 2);
}
- TCPEntry e;
- EXPECT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()));
- EXPECT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr()));
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()));
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr()));
}
TEST(ProcNetTCP, InodeReasonable) {
@@ -261,8 +234,8 @@ TEST(ProcNetTCP, State) {
FileDescriptor accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(server->get(), nullptr, nullptr));
- const uint32_t accepted_local_host = IP(&addr);
- const uint16_t accepted_local_port = Port(&addr);
+ const uint32_t accepted_local_host = IPFromInetSockaddr(&addr);
+ const uint16_t accepted_local_port = PortFromInetSockaddr(&addr);
entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCPEntries());
TCPEntry accepted_entry;
diff --git a/test/syscalls/linux/proc_net_udp.cc b/test/syscalls/linux/proc_net_udp.cc
new file mode 100644
index 000000000..369df8e0e
--- /dev/null
+++ b/test/syscalls/linux/proc_net_udp.cc
@@ -0,0 +1,309 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using absl::StrCat;
+using absl::StrFormat;
+using absl::StrSplit;
+
+constexpr char kProcNetUDPHeader[] =
+ " sl local_address rem_address st tx_queue rx_queue tr tm->when "
+ "retrnsmt uid timeout inode ref pointer drops ";
+
+// UDPEntry represents a single entry from /proc/net/udp.
+struct UDPEntry {
+ uint32_t local_addr;
+ uint16_t local_port;
+
+ uint32_t remote_addr;
+ uint16_t remote_port;
+
+ uint64_t state;
+ uint64_t uid;
+ uint64_t inode;
+};
+
+std::string DescribeFirstInetSocket(const SocketPair& sockets) {
+ const struct sockaddr* addr = sockets.first_addr();
+ return StrFormat("First test socket: fd:%d %8X:%4X", sockets.first_fd(),
+ IPFromInetSockaddr(addr), PortFromInetSockaddr(addr));
+}
+
+std::string DescribeSecondInetSocket(const SocketPair& sockets) {
+ const struct sockaddr* addr = sockets.second_addr();
+ return StrFormat("Second test socket fd:%d %8X:%4X", sockets.second_fd(),
+ IPFromInetSockaddr(addr), PortFromInetSockaddr(addr));
+}
+
+// Finds the first entry in 'entries' for which 'predicate' returns true.
+// Returns true on match, and set 'match' to a copy of the matching entry. If
+// 'match' is null, it's ignored.
+bool FindBy(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ std::function<bool(const UDPEntry&)> predicate) {
+ for (const UDPEntry& entry : entries) {
+ if (predicate(entry)) {
+ if (match != nullptr) {
+ *match = entry;
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+bool FindByLocalAddr(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const UDPEntry& e) {
+ return (e.local_addr == host && e.local_port == port);
+ });
+}
+
+bool FindByRemoteAddr(const std::vector<UDPEntry>& entries, UDPEntry* match,
+ const struct sockaddr* addr) {
+ uint32_t host = IPFromInetSockaddr(addr);
+ uint16_t port = PortFromInetSockaddr(addr);
+ return FindBy(entries, match, [host, port](const UDPEntry& e) {
+ return (e.remote_addr == host && e.remote_port == port);
+ });
+}
+
+PosixErrorOr<uint64_t> InodeFromSocketFD(int fd) {
+ ASSIGN_OR_RETURN_ERRNO(struct stat s, Fstat(fd));
+ if (!S_ISSOCK(s.st_mode)) {
+ return PosixError(EINVAL, StrFormat("FD %d is not a socket", fd));
+ }
+ return s.st_ino;
+}
+
+PosixErrorOr<bool> FindByFD(const std::vector<UDPEntry>& entries,
+ UDPEntry* match, int fd) {
+ ASSIGN_OR_RETURN_ERRNO(uint64_t inode, InodeFromSocketFD(fd));
+ return FindBy(entries, match,
+ [inode](const UDPEntry& e) { return (e.inode == inode); });
+}
+
+// Returns a parsed representation of /proc/net/udp entries.
+PosixErrorOr<std::vector<UDPEntry>> ProcNetUDPEntries() {
+ std::string content;
+ RETURN_IF_ERRNO(GetContents("/proc/net/udp", &content));
+
+ bool found_header = false;
+ std::vector<UDPEntry> entries;
+ std::vector<std::string> lines = StrSplit(content, '\n');
+ std::cerr << "<contents of /proc/net/udp>" << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (!found_header) {
+ EXPECT_EQ(line, kProcNetUDPHeader);
+ found_header = true;
+ continue;
+ }
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/net/udp.
+ //
+ // Example entries:
+ //
+ // clang-format off
+ //
+ // sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode ref pointer drops
+ // 3503: 0100007F:0035 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 33317 2 0000000000000000 0
+ // 3518: 00000000:0044 00000000:0000 07 00000000:00000000 00:00000000 00000000 0 0 40394 2 0000000000000000 0
+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
+ //
+ // clang-format on
+
+ UDPEntry entry;
+ std::vector<std::string> fields =
+ StrSplit(line, absl::ByAnyChar(": "), absl::SkipEmpty());
+
+ ASSIGN_OR_RETURN_ERRNO(entry.local_addr, AtoiBase(fields[1], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.local_port, AtoiBase(fields[2], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(entry.remote_addr, AtoiBase(fields[3], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.remote_port, AtoiBase(fields[4], 16));
+
+ ASSIGN_OR_RETURN_ERRNO(entry.state, AtoiBase(fields[5], 16));
+ ASSIGN_OR_RETURN_ERRNO(entry.uid, Atoi<uint64_t>(fields[11]));
+ ASSIGN_OR_RETURN_ERRNO(entry.inode, Atoi<uint64_t>(fields[13]));
+
+ // Linux shares internal data structures between TCP and UDP sockets. The
+ // proc entries for UDP sockets share some fields with TCP sockets, but
+ // these fields should always be zero as they're not meaningful for UDP
+ // sockets.
+ EXPECT_EQ(fields[8], "00") << StrFormat("sl:%s, tr", fields[0]);
+ EXPECT_EQ(fields[9], "00000000") << StrFormat("sl:%s, tm->when", fields[0]);
+ EXPECT_EQ(fields[10], "00000000")
+ << StrFormat("sl:%s, retrnsmt", fields[0]);
+ EXPECT_EQ(fields[12], "0") << StrFormat("sl:%s, timeout", fields[0]);
+
+ entries.push_back(entry);
+ }
+ std::cerr << "<end of /proc/net/udp>" << std::endl;
+
+ return entries;
+}
+
+TEST(ProcNetUDP, Exists) {
+ const std::string content =
+ ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/net/udp"));
+ const std::string header_line = StrCat(kProcNetUDPHeader, "\n");
+ EXPECT_THAT(content, ::testing::StartsWith(header_line));
+}
+
+TEST(ProcNetUDP, EntryUID) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_EQ(e.uid, geteuid());
+ ASSERT_TRUE(FindByRemoteAddr(entries, &e, sockets->first_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_EQ(e.uid, geteuid());
+}
+
+TEST(ProcNetUDP, FindMutualEntries) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_TRUE(FindByRemoteAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeFirstInetSocket(*sockets);
+}
+
+TEST(ProcNetUDP, EntriesRemovedOnClose) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_THAT(close(sockets->release_first_fd()), SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ // First socket's entry should be gone, but the second socket's entry should
+ // still exist.
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_TRUE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+
+ EXPECT_THAT(close(sockets->release_second_fd()), SyscallSucceeds());
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ // Both entries should be gone.
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_FALSE(FindByLocalAddr(entries, nullptr, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+}
+
+PosixErrorOr<std::unique_ptr<FileDescriptor>> BoundUDPSocket() {
+ ASSIGN_OR_RETURN_ERRNO(std::unique_ptr<FileDescriptor> socket,
+ IPv4UDPUnboundSocket(0).Create());
+ struct sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_ANY);
+ addr.sin_port = 0;
+
+ int res = bind(socket->get(), reinterpret_cast<const struct sockaddr*>(&addr),
+ sizeof(addr));
+ if (res) {
+ return PosixError(errno, "bind()");
+ }
+ return socket;
+}
+
+TEST(ProcNetUDP, BoundEntry) {
+ std::unique_ptr<FileDescriptor> socket =
+ ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket());
+ struct sockaddr addr;
+ socklen_t len = sizeof(addr);
+ ASSERT_THAT(getsockname(socket->get(), &addr, &len), SyscallSucceeds());
+ uint16_t port = PortFromInetSockaddr(&addr);
+
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get())));
+ EXPECT_EQ(e.local_port, port);
+ EXPECT_EQ(e.remote_addr, 0);
+ EXPECT_EQ(e.remote_port, 0);
+}
+
+TEST(ProcNetUDP, BoundSocketStateClosed) {
+ std::unique_ptr<FileDescriptor> socket =
+ ASSERT_NO_ERRNO_AND_VALUE(BoundUDPSocket());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+ UDPEntry e;
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(FindByFD(entries, &e, socket->get())));
+ EXPECT_EQ(e.state, TCP_CLOSE);
+}
+
+TEST(ProcNetUDP, ConnectedSocketStateEstablished) {
+ auto sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(IPv4UDPBidirectionalBindSocketPair(0).Create());
+ std::vector<UDPEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcNetUDPEntries());
+
+ UDPEntry e;
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->first_addr()))
+ << DescribeFirstInetSocket(*sockets);
+ EXPECT_EQ(e.state, TCP_ESTABLISHED);
+
+ ASSERT_TRUE(FindByLocalAddr(entries, &e, sockets->second_addr()))
+ << DescribeSecondInetSocket(*sockets);
+ EXPECT_EQ(e.state, TCP_ESTABLISHED);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc
index 9b9be66ff..83dbd1364 100644
--- a/test/syscalls/linux/proc_net_unix.cc
+++ b/test/syscalls/linux/proc_net_unix.cc
@@ -56,19 +56,44 @@ struct UnixEntry {
std::string path;
};
+// Abstract socket paths can have either trailing null bytes or '@'s as padding
+// at the end, depending on the linux version. This function strips any such
+// padding.
+void StripAbstractPathPadding(std::string* s) {
+ const char pad_char = s->back();
+ if (pad_char != '\0' && pad_char != '@') {
+ return;
+ }
+
+ const auto last_pos = s->find_last_not_of(pad_char);
+ if (last_pos != std::string::npos) {
+ s->resize(last_pos + 1);
+ }
+}
+
+// Precondition: addr must be a unix socket address (i.e. sockaddr_un) and
+// addr->sun_path must be null-terminated. This is always the case if addr comes
+// from Linux:
+//
+// Per man unix(7):
+//
+// "When the address of a pathname socket is returned (by [getsockname(2)]), its
+// length is
+//
+// offsetof(struct sockaddr_un, sun_path) + strlen(sun_path) + 1
+//
+// and sun_path contains the null-terminated pathname."
std::string ExtractPath(const struct sockaddr* addr) {
const char* path =
reinterpret_cast<const struct sockaddr_un*>(addr)->sun_path;
// Note: sockaddr_un.sun_path is an embedded character array of length
// UNIX_PATH_MAX, so we can always safely dereference the first 2 bytes below.
//
- // The kernel also enforces that the path is always null terminated.
+ // We also rely on the path being null-terminated.
if (path[0] == 0) {
- // Abstract socket paths are null padded to the end of the struct
- // sockaddr. However, these null bytes may or may not show up in
- // /proc/net/unix depending on the kernel version. Truncate after the first
- // null byte (by treating path as a c-string).
- return StrCat("@", &path[1]);
+ std::string abstract_path = StrCat("@", &path[1]);
+ StripAbstractPathPadding(&abstract_path);
+ return abstract_path;
}
return std::string(path);
}
@@ -96,14 +121,6 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
continue;
}
- // Abstract socket paths can have trailing null bytes in them depending on
- // the linux version. Strip off everything after a null byte, including the
- // null byte.
- std::size_t null_pos = line.find('\0');
- if (null_pos != std::string::npos) {
- line.erase(null_pos);
- }
-
// Parse a single entry from /proc/net/unix.
//
// Sample file:
@@ -151,6 +168,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() {
entry.path = "";
if (fields.size() > 1) {
entry.path = fields[1];
+ StripAbstractPathPadding(&entry.path);
}
entries.push_back(entry);
@@ -200,8 +218,8 @@ TEST(ProcNetUnix, FilesystemBindAcceptConnect) {
std::string path1 = ExtractPath(sockets->first_addr());
std::string path2 = ExtractPath(sockets->second_addr());
- std::cout << StreamFormat("Server socket address: %s\n", path1);
- std::cout << StreamFormat("Client socket address: %s\n", path2);
+ std::cerr << StreamFormat("Server socket address (path1): %s\n", path1);
+ std::cerr << StreamFormat("Client socket address (path2): %s\n", path2);
std::vector<UnixEntry> entries =
ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
@@ -224,8 +242,8 @@ TEST(ProcNetUnix, AbstractBindAcceptConnect) {
std::string path1 = ExtractPath(sockets->first_addr());
std::string path2 = ExtractPath(sockets->second_addr());
- std::cout << StreamFormat("Server socket address: '%s'\n", path1);
- std::cout << StreamFormat("Client socket address: '%s'\n", path2);
+ std::cerr << StreamFormat("Server socket address (path1): '%s'\n", path1);
+ std::cerr << StreamFormat("Client socket address (path2): '%s'\n", path2);
std::vector<UnixEntry> entries =
ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries());
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index abf2b1a04..8f3800380 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -27,6 +27,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/logging.h"
@@ -36,10 +37,10 @@
#include "test/util/thread_util.h"
#include "test/util/time_util.h"
-DEFINE_bool(ptrace_test_execve_child, false,
- "If true, run the "
- "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_"
- "TraceExit child workload.");
+ABSL_FLAG(bool, ptrace_test_execve_child, false,
+ "If true, run the "
+ "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_"
+ "TraceExit child workload.");
namespace gvisor {
namespace testing {
@@ -1206,7 +1207,7 @@ TEST(PtraceTest, SeizeSetOptions) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_ptrace_test_execve_child) {
+ if (absl::GetFlag(FLAGS_ptrace_test_execve_child)) {
gvisor::testing::RunExecveChild();
}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index bd6907876..99a0df235 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1292,10 +1292,9 @@ TEST_F(JobControlTest, ReleaseTTY) {
// Make sure we're ignoring SIGHUP, which will be sent to this process once we
// disconnect they TTY.
- struct sigaction sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
struct sigaction old_sa;
EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds());
@@ -1362,10 +1361,9 @@ TEST_F(JobControlTest, ReleaseTTYSignals) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
received = 0;
- struct sigaction sa = {
- .sa_handler = sig_handler,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = sig_handler;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
sigaddset(&sa.sa_mask, SIGHUP);
sigaddset(&sa.sa_mask, SIGCONT);
@@ -1403,10 +1401,9 @@ TEST_F(JobControlTest, ReleaseTTYSignals) {
// Make sure we're ignoring SIGHUP, which will be sent to this process once we
// disconnect they TTY.
- struct sigaction sighup_sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sighup_sa = {};
+ sighup_sa.sa_handler = SIG_IGN;
+ sighup_sa.sa_flags = 0;
sigemptyset(&sighup_sa.sa_mask);
struct sigaction old_sa;
EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds());
@@ -1456,10 +1453,9 @@ TEST_F(JobControlTest, SetForegroundProcessGroup) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
// Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp.
- struct sigaction sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
sigaction(SIGTTOU, &sa, NULL);
@@ -1531,27 +1527,70 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+ int sync_setsid[2];
+ int sync_exit[2];
+ ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds());
+ ASSERT_THAT(pipe(sync_exit), SyscallSucceeds());
+
// Create a new process and put it in a new session.
pid_t child = fork();
if (!child) {
TEST_PCHECK(setsid() >= 0);
// Tell the parent we're in a new session.
- TEST_PCHECK(!raise(SIGSTOP));
- TEST_PCHECK(!pause());
- _exit(1);
+ char c = 'c';
+ TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1);
+ TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1);
+ _exit(0);
}
// Wait for the child to tell us it's in a new session.
- int wstatus;
- EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED),
- SyscallSucceedsWithValue(child));
- EXPECT_TRUE(WSTOPSIG(wstatus));
+ char c = 'c';
+ ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1));
// Child is in a new session, so we can't make it the foregroup process group.
EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
SyscallFailsWithErrno(EPERM));
- EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds());
+ EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1));
+
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(wstatus));
+ EXPECT_EQ(WEXITSTATUS(wstatus), 0);
+}
+
+// Verify that we don't hang when creating a new session from an orphaned
+// process group (b/139968068). Calling setsid() creates an orphaned process
+// group, as process groups that contain the session's leading process are
+// orphans.
+//
+// We create 2 sessions in this test. The init process in gVisor is considered
+// not to be an orphan (see sessions.go), so we have to create a session from
+// which to create a session. The latter session is being created from an
+// orphaned process group.
+TEST_F(JobControlTest, OrphanRegression) {
+ pid_t session_2_leader = fork();
+ if (!session_2_leader) {
+ TEST_PCHECK(setsid() >= 0);
+
+ pid_t session_3_leader = fork();
+ if (!session_3_leader) {
+ TEST_PCHECK(setsid() >= 0);
+
+ _exit(0);
+ }
+
+ int wstatus;
+ TEST_PCHECK(waitpid(session_3_leader, &wstatus, 0) == session_3_leader);
+ TEST_PCHECK(wstatus == 0);
+
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(session_2_leader, &wstatus, 0),
+ SyscallSucceedsWithValue(session_2_leader));
+ ASSERT_EQ(wstatus, 0);
}
} // namespace
diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc
index d2a321a6e..14a4af980 100644
--- a/test/syscalls/linux/pty_root.cc
+++ b/test/syscalls/linux/pty_root.cc
@@ -50,7 +50,7 @@ TEST(JobControlRootTest, StealTTY) {
// of 1.
pid_t child = fork();
if (!child) {
- TEST_PCHECK(setsid() >= 0);
+ ASSERT_THAT(setsid(), SyscallSucceeds());
// We shouldn't be able to steal the terminal with the wrong arg value.
TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0));
// We should be able to steal it here.
diff --git a/test/syscalls/linux/pwritev2.cc b/test/syscalls/linux/pwritev2.cc
index db519f4e0..f6a0fc96c 100644
--- a/test/syscalls/linux/pwritev2.cc
+++ b/test/syscalls/linux/pwritev2.cc
@@ -244,8 +244,10 @@ TEST(Pwritev2Test, TestInvalidOffset) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR));
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
/*offset=*/static_cast<off_t>(-8), /*flags=*/0),
@@ -286,8 +288,10 @@ TEST(Pwritev2Test, TestUnseekableFileInValid) {
SKIP_IF(pwritev2(-1, nullptr, 0, 0, 0) < 0 && errno == ENOSYS);
int pipe_fds[2];
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
ASSERT_THAT(pipe(pipe_fds), SyscallSucceeds());
@@ -307,8 +311,10 @@ TEST(Pwritev2Test, TestReadOnlyFile) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
/*offset=*/0, /*flags=*/0),
@@ -324,8 +330,10 @@ TEST(Pwritev2Test, TestInvalidFlag) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR | O_DIRECT));
+ char buf[16];
struct iovec iov;
- iov.iov_base = nullptr;
+ iov.iov_base = buf;
+ iov.iov_len = sizeof(buf);
EXPECT_THAT(pwritev2(fd.get(), &iov, /*iovcnt=*/1,
/*offset=*/0, /*flags=*/0xF0),
diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc
index a070817eb..0a27506aa 100644
--- a/test/syscalls/linux/raw_socket_hdrincl.cc
+++ b/test/syscalls/linux/raw_socket_hdrincl.cc
@@ -63,7 +63,11 @@ class RawHDRINCL : public ::testing::Test {
};
void RawHDRINCL::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
SyscallSucceeds());
@@ -76,9 +80,10 @@ void RawHDRINCL::SetUp() {
}
void RawHDRINCL::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
struct iphdr RawHDRINCL::LoopbackHeader() {
@@ -123,8 +128,6 @@ bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port,
// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup
// creates the first one, so we only have to create one more here.
TEST_F(RawHDRINCL, MultipleCreation) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int s2;
ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds());
@@ -133,23 +136,17 @@ TEST_F(RawHDRINCL, MultipleCreation) {
// Test that shutting down an unconnected socket fails.
TEST_F(RawHDRINCL, FailShutdownWithoutConnect) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
}
// Test that listen() fails.
TEST_F(RawHDRINCL, FailListen) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP));
}
// Test that accept() fails.
TEST_F(RawHDRINCL, FailAccept) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr saddr;
socklen_t addrlen;
ASSERT_THAT(accept(socket_, &saddr, &addrlen),
@@ -158,8 +155,6 @@ TEST_F(RawHDRINCL, FailAccept) {
// Test that the socket is writable immediately.
TEST_F(RawHDRINCL, PollWritableImmediately) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct pollfd pfd = {};
pfd.fd = socket_;
pfd.events = POLLOUT;
@@ -168,8 +163,6 @@ TEST_F(RawHDRINCL, PollWritableImmediately) {
// Test that the socket isn't readable.
TEST_F(RawHDRINCL, NotReadable) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
// Try to receive data with MSG_DONTWAIT, which returns immediately if there's
// nothing to be read.
char buf[117];
@@ -179,16 +172,12 @@ TEST_F(RawHDRINCL, NotReadable) {
// Test that we can connect() to a valid IP (loopback).
TEST_F(RawHDRINCL, ConnectToLoopback) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
}
TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct iphdr hdr = LoopbackHeader();
ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0),
SyscallSucceedsWithValue(sizeof(hdr)));
@@ -197,8 +186,6 @@ TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
// HDRINCL implies write-only. Verify that we can't read a packet sent to
// loopback.
TEST_F(RawHDRINCL, NotReadableAfterWrite) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -221,8 +208,6 @@ TEST_F(RawHDRINCL, NotReadableAfterWrite) {
}
TEST_F(RawHDRINCL, WriteTooSmall) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -235,8 +220,6 @@ TEST_F(RawHDRINCL, WriteTooSmall) {
// Bind to localhost.
TEST_F(RawHDRINCL, BindToLocalhost) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(
bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
SyscallSucceeds());
@@ -244,8 +227,6 @@ TEST_F(RawHDRINCL, BindToLocalhost) {
// Bind to a different address.
TEST_F(RawHDRINCL, BindToInvalid) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr_in bind_addr = {};
bind_addr.sin_family = AF_INET;
bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
@@ -256,8 +237,6 @@ TEST_F(RawHDRINCL, BindToInvalid) {
// Send and receive a packet.
TEST_F(RawHDRINCL, SendAndReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -302,8 +281,6 @@ TEST_F(RawHDRINCL, SendAndReceive) {
// Send and receive a packet with nonzero IP ID.
TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -349,8 +326,6 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
// Send and receive a packet where the sendto address is not the same as the
// provided destination.
TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 971592d7d..8bcaba6f1 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -77,7 +77,11 @@ class RawSocketICMPTest : public ::testing::Test {
};
void RawSocketICMPTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
@@ -90,9 +94,10 @@ void RawSocketICMPTest::SetUp() {
}
void RawSocketICMPTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We'll only read an echo in this case, as the kernel won't respond to the
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc
index 352037c88..cde2f07c9 100644
--- a/test/syscalls/linux/raw_socket_ipv4.cc
+++ b/test/syscalls/linux/raw_socket_ipv4.cc
@@ -67,7 +67,11 @@ class RawSocketTest : public ::testing::TestWithParam<int> {
};
void RawSocketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
@@ -79,9 +83,10 @@ void RawSocketTest::SetUp() {
}
void RawSocketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We should be able to create multiple raw sockets for the same protocol.
diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc
new file mode 100644
index 000000000..09703b5c1
--- /dev/null
+++ b/test/syscalls/linux/readahead.cc
@@ -0,0 +1,91 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ReadaheadTest, InvalidFD) {
+ EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ReadaheadTest, InvalidOffset) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), -1, 1), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(ReadaheadTest, ValidOffset) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // N.B. The implementation of readahead is filesystem-specific, and a file
+ // backed by ram may return EINVAL because there is nothing to be read.
+ EXPECT_THAT(readahead(fd.get(), 1, 1), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, PastEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // See above.
+ EXPECT_THAT(readahead(fd.get(), 2, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, CrossesEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // See above.
+ EXPECT_THAT(readahead(fd.get(), 4, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, WriteOnly) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_WRONLY));
+ EXPECT_THAT(readahead(fd.get(), 0, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ReadaheadTest, InvalidSize) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), 0, -1), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 421318fcb..40c57f543 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -15,6 +15,7 @@
#include <sys/ipc.h>
#include <sys/sem.h>
#include <sys/types.h>
+
#include <atomic>
#include <cerrno>
#include <ctime>
@@ -22,6 +23,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/base/macros.h"
+#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "test/util/capability_util.h"
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index e5d72e28a..4502e7fb4 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -19,9 +19,12 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -299,10 +302,30 @@ TEST(SendFileTest, DoNotSendfileIfOutfileIsAppendOnly) {
// Open the output file as append only.
const FileDescriptor outf =
- ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_APPEND));
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY | O_APPEND));
// Send data and verify that sendfile returns the correct errno.
EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(SendFileTest, AppendCheckOrdering) {
+ constexpr char kData[] = "And by opposing end them: to die, to sleep";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ const FileDescriptor read =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY));
+ const FileDescriptor write =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY));
+ const FileDescriptor append =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_APPEND));
+
+ // Check that read/write file mode is verified before append.
+ EXPECT_THAT(sendfile(append.get(), read.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EBADF));
+ EXPECT_THAT(sendfile(write.get(), write.get(), nullptr, kDataSize),
SyscallFailsWithErrno(EBADF));
}
@@ -422,6 +445,72 @@ TEST(SendFileTest, SendToNotARegularFile) {
EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0),
SyscallFailsWithErrno(EINVAL));
}
+
+TEST(SendFileTest, SendPipeWouldBlock) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fool doth think he is wise, but the wise man knows himself to be a "
+ "fool.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Setup the output named pipe.
+ int fds[2];
+ ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(2 * pipe_size);
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST(SendFileTest, SendPipeBlocks) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fault, dear Brutus, is not in our stars, but in ourselves.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Setup the output named pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(pipe_size);
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
new file mode 100644
index 000000000..54c598627
--- /dev/null
+++ b/test/syscalls/linux/signalfd.cc
@@ -0,0 +1,333 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <poll.h>
+#include <signal.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/signalfd.h>
+#include <unistd.h>
+
+#include <functional>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+using ::testing::KilledBySignal;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kSigno = SIGUSR1;
+constexpr int kSignoAlt = SIGUSR2;
+
+// Returns a new signalfd.
+inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) {
+ int fd = signalfd(-1, mask, flags);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "signalfd");
+ }
+ return FileDescriptor(fd);
+}
+
+TEST(Signalfd, Basic) {
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Deliver the blocked signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // We should now read the signal.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+}
+
+TEST(Signalfd, MaskWorks) {
+ // Create two signalfds with different masks.
+ sigset_t mask1, mask2;
+ sigemptyset(&mask1);
+ sigemptyset(&mask2);
+ sigaddset(&mask1, kSigno);
+ sigaddset(&mask2, kSignoAlt);
+ FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0));
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0));
+
+ // Deliver the two signals.
+ const auto scoped_sigmask1 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ const auto scoped_sigmask2 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds());
+
+ // We should see the signals on the appropriate signalfds.
+ //
+ // We read in the opposite order as the signals deliver above, to ensure that
+ // we don't happen to read the correct signal from the correct signalfd.
+ struct signalfd_siginfo rbuf1, rbuf2;
+ ASSERT_THAT(read(fd2.get(), &rbuf2, sizeof(rbuf2)),
+ SyscallSucceedsWithValue(sizeof(rbuf2)));
+ EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt);
+ ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)),
+ SyscallSucceedsWithValue(sizeof(rbuf1)));
+ EXPECT_EQ(rbuf1.ssi_signo, kSigno);
+}
+
+TEST(Signalfd, Cloexec) {
+ // Exec tests confirm that O_CLOEXEC has the intended effect. We just create a
+ // signalfd with the appropriate flag here and assert that the FD has it set.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+ EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
+TEST(Signalfd, Blocking) {
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared tid variable.
+ absl::Mutex mu;
+ bool has_tid;
+ pid_t tid;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Copy the tid and notify the caller.
+ {
+ absl::MutexLock ml(&mu);
+ tid = gettid();
+ has_tid = true;
+ }
+
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ });
+
+ // Wait until blocked.
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&has_tid));
+
+ // Deliver the signal to either the waiting thread, or
+ // to this thread. N.B. this is a bug in the core gVisor
+ // behavior for signalfd, and needs to be fixed.
+ //
+ // See gvisor.dev/issue/139.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ } else {
+ ASSERT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds());
+ }
+
+ // Ensure that it was received.
+ t.Join();
+}
+
+TEST(Signalfd, ThreadGroup) {
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared variable.
+ absl::Mutex mu;
+ bool first = false;
+ bool second = false;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Wait for the other thread.
+ absl::MutexLock ml(&mu);
+ first = true;
+ mu.Await(absl::Condition(&second));
+ });
+
+ // Deliver the signal to the threadgroup.
+ ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+
+ // Wait for the first thread to process.
+ {
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&first));
+ }
+
+ // Deliver to the thread group again (other thread still exists).
+ ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+
+ // Ensure that we can also receive it.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Mark the test as done.
+ {
+ absl::MutexLock ml(&mu);
+ second = true;
+ }
+
+ // The other thread should be joinable.
+ t.Join();
+}
+
+TEST(Signalfd, Nonblock) {
+ // Create the signalfd in non-blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+ // We should return if we attempt to read.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Block and deliver the signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // Ensure that a read actually works.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Should block again.
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST(Signalfd, SetMask) {
+ // Create the signalfd matching nothing.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+ // Block and deliver a signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // We should have nothing.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Change the signal mask.
+ sigaddset(&mask, kSigno);
+ ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds());
+
+ // We should now have the signal.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+}
+
+TEST(Signalfd, Poll) {
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Block the signal, and start a thread to deliver it.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ pid_t orig_tid = gettid();
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Seconds(5));
+ ASSERT_THAT(tgkill(getpid(), orig_tid, kSigno), SyscallSucceeds());
+ });
+
+ // Start polling for the signal. We expect that it is not available at the
+ // outset, but then becomes available when the signal is sent. We give a
+ // timeout of 10000ms (or the delay above + 5 seconds of additional grace
+ // time).
+ struct pollfd poll_fd = {fd.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ // Actually read the signal to prevent delivery.
+ struct signalfd_siginfo rbuf;
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+}
+
+TEST(Signalfd, KillStillKills) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+
+ // Just because there is a signalfd, we shouldn't see any change in behavior
+ // for unblockable signals. It's easier to test this with SIGKILL.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL));
+ EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // These tests depend on delivering signals. Block them up front so that all
+ // other threads created by TestInit will also have them blocked, and they
+ // will not interface with the rest of the test.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, gvisor::testing::kSigno);
+ sigaddset(&set, gvisor::testing::kSignoAlt);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc
index 9c7210e17..7db57d968 100644
--- a/test/syscalls/linux/sigstop.cc
+++ b/test/syscalls/linux/sigstop.cc
@@ -17,6 +17,7 @@
#include <sys/select.h>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/multiprocess_util.h"
@@ -24,8 +25,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(sigstop_test_child, false,
- "If true, run the SigstopTest child workload.");
+ABSL_FLAG(bool, sigstop_test_child, false,
+ "If true, run the SigstopTest child workload.");
namespace gvisor {
namespace testing {
@@ -141,7 +142,7 @@ void RunChild() {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_sigstop_test_child) {
+ if (absl::GetFlag(FLAGS_sigstop_test_child)) {
gvisor::testing::RunChild();
return 1;
}
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index 0404190a0..caae215b8 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -30,12 +30,25 @@ TEST(SocketTest, UnixSocketPairProtocol) {
close(socks[1]);
}
-TEST(SocketTest, Protocol) {
+TEST(SocketTest, ProtocolUnix) {
struct {
int domain, type, protocol;
} tests[] = {
- {AF_UNIX, SOCK_STREAM, PF_UNIX}, {AF_UNIX, SOCK_SEQPACKET, PF_UNIX},
- {AF_UNIX, SOCK_DGRAM, PF_UNIX}, {AF_INET, SOCK_DGRAM, IPPROTO_UDP},
+ {AF_UNIX, SOCK_STREAM, PF_UNIX},
+ {AF_UNIX, SOCK_SEQPACKET, PF_UNIX},
+ {AF_UNIX, SOCK_DGRAM, PF_UNIX},
+ };
+ for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
+ ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(tests[i].domain, tests[i].type, tests[i].protocol));
+ }
+}
+
+TEST(SocketTest, ProtocolInet) {
+ struct {
+ int domain, type, protocol;
+ } tests[] = {
+ {AF_INET, SOCK_DGRAM, IPPROTO_UDP},
{AF_INET, SOCK_STREAM, IPPROTO_TCP},
};
for (int i = 0; i < ABSL_ARRAYSIZE(tests); i++) {
diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc
new file mode 100644
index 000000000..d20821cac
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device.cc
@@ -0,0 +1,314 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+// Test fixture for SO_BINDTODEVICE tests.
+class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+
+ interface_name_ = "eth1";
+ auto interface_names = GetInterfaceNames();
+ if (interface_names.find(interface_name_) == interface_names.end()) {
+ // Need a tunnel.
+ tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New());
+ interface_name_ = tunnel_->GetName();
+ ASSERT_FALSE(interface_name_.empty());
+ }
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create());
+ }
+
+ string interface_name() const { return interface_name_; }
+
+ int socket_fd() const { return socket_->get(); }
+
+ private:
+ std::unique_ptr<Tunnel> tunnel_;
+ string interface_name_;
+ std::unique_ptr<FileDescriptor> socket_;
+};
+
+constexpr char kIllegalIfnameChar = '/';
+
+// Tests getsockopt of the default value.
+TEST_P(BindToDeviceTest, GetsockoptDefault) {
+ char name_buffer[IFNAMSIZ * 2];
+ char original_name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Read the default SO_BINDTODEVICE.
+ memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(name_buffer_size, 0);
+ EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)),
+ 0);
+ }
+}
+
+// Tests setsockopt of invalid device name.
+TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Set an invalid device name.
+ memset(name_buffer, kIllegalIfnameChar, 5);
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// Tests setsockopt of a buffer with a valid device name but not
+// null-terminated, with different sizes of buffer.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Intentionally overwrite the null at the end.
+ memset(name_buffer + interface_name().size(), kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size());
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is exactly right.
+ if (name_buffer_size == interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests setsockopt of a buffer with a valid device name and null-terminated,
+// with different sizes of buffer.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Don't overwrite the null at the end.
+ memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size() - 1);
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is at least the right size.
+ if (name_buffer_size >= interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests that setsockopt of an invalid device name doesn't unset the previous
+// valid setsockopt.
+TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Write unsuccessfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallFailsWithErrno(ENODEV));
+
+ // Read it back successfully, it's unchanged.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+}
+
+// Tests that setsockopt of zero-length string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClear) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ name_buffer_size = 0;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests that setsockopt of empty string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer[0] = 0;
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests getsockopt with different buffer sizes.
+TEST_P(BindToDeviceTest, GetsockoptDevice) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back at various buffer sizes.
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice
+ // for this interface name.
+ if (name_buffer_size >= IFNAMSIZ) {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+ } else {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_EQ(name_buffer_size, i);
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
new file mode 100644
index 000000000..4d2400328
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -0,0 +1,381 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <atomic>
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+struct EndpointConfig {
+ std::string bind_to_device;
+ double expected_ratio;
+};
+
+struct DistributionTestCase {
+ std::string name;
+ std::vector<EndpointConfig> endpoints;
+};
+
+struct ListenerConnector {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+// Test fixture for SO_BINDTODEVICE tests the distribution of packets received
+// with varying SO_BINDTODEVICE settings.
+class BindToDeviceDistributionTest
+ : public ::testing::TestWithParam<
+ ::testing::tuple<ListenerConnector, DistributionTestCase>> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s, listener=%s, connector=%s\n",
+ ::testing::get<1>(GetParam()).name.c_str(),
+ ::testing::get<0>(GetParam()).listener.description.c_str(),
+ ::testing::get<0>(GetParam()).connector.description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ }
+};
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+// Binds sockets to different devices and then creates many TCP connections.
+// Checks that the distribution of connections received on the sockets matches
+// the expectation.
+TEST_P(BindToDeviceDistributionTest, Tcp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening sockets.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> accept_counts(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> listen_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ listen_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, &connects_received, i,
+ kConnectAttempts]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ // Another thread has shutdown our read side causing the accept to
+ // fail.
+ ASSERT_GE(connects_received, kConnectAttempts)
+ << "errno = " << fd.error();
+ return;
+ }
+ // Receive some data from a socket to be sure that the connect()
+ // system call has been completed on another side.
+ int data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& listen_thread : listen_threads) {
+ listen_thread->Join();
+ }
+ // Check that connections are distributed correctly among listening sockets.
+ for (int i = 0; i < accept_counts.size(); i++) {
+ EXPECT_THAT(
+ accept_counts[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+// Binds sockets to different devices and then sends many UDP packets. Checks
+// that the distribution of packets received on the sockets matches the
+// expectation.
+TEST_P(BindToDeviceDistributionTest, Udp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening socket.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> packets_per_socket(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ receiver_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, &packets_received, i]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shutdown our read side causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread,
+ // otherwise the main thread can send more than can fit into receive
+ // queues.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& receiver_thread : receiver_threads) {
+ receiver_thread->Join();
+ }
+ // Check that packets are distributed correctly among listening sockets.
+ for (int i = 0; i < packets_per_socket.size(); i++) {
+ EXPECT_THAT(
+ packets_per_socket[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+std::vector<DistributionTestCase> GetDistributionTestCases() {
+ return std::vector<DistributionTestCase>{
+ {"Even distribution among sockets not bound to device",
+ {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}},
+ {"Sockets bound to other interfaces get no packets",
+ {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}},
+ {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}},
+ {"Even distribution among sockets bound to device",
+ {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}},
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BindToDeviceTest, BindToDeviceDistributionTest,
+ ::testing::Combine(::testing::Values(
+ // Listeners bound to IPv4 addresses refuse
+ // connections using IPv6 addresses.
+ ListenerConnector{V4Any(), V4Loopback()},
+ ListenerConnector{V4Loopback(), V4MappedLoopback()}),
+ ::testing::ValuesIn(GetDistributionTestCases())));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
new file mode 100644
index 000000000..a7365d139
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -0,0 +1,316 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket
+// binding.
+class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ socket_factory_ = GetParam();
+
+ interface_names_ = GetInterfaceNames();
+ }
+
+ PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const {
+ return socket_factory_.Create();
+ }
+
+ // Gets a device by device_id. If the device_id has been seen before, returns
+ // the previously returned device. If not, finds or creates a new device.
+ // Returns an empty string on failure.
+ void GetDevice(int device_id, string *device_name) {
+ auto device = devices_.find(device_id);
+ if (device != devices_.end()) {
+ *device_name = device->second;
+ return;
+ }
+
+ // Need to pick a new device. Try ethernet first.
+ *device_name = absl::StrCat("eth", next_unused_eth_);
+ if (interface_names_.find(*device_name) != interface_names_.end()) {
+ devices_[device_id] = *device_name;
+ next_unused_eth_++;
+ return;
+ }
+
+ // Need to make a new tunnel device. gVisor tests should have enough
+ // ethernet devices to never reach here.
+ ASSERT_FALSE(IsRunningOnGvisor());
+ // Need a tunnel.
+ tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()));
+ devices_[device_id] = tunnels_.back()->GetName();
+ *device_name = devices_[device_id];
+ }
+
+ // Release the socket
+ void ReleaseSocket(int socket_id) {
+ // Close the socket that was made in a previous action. The socket_id
+ // indicates which socket to close based on index into the list of actions.
+ sockets_to_close_.erase(socket_id);
+ }
+
+ // Bind a socket with the reuse option and bind_to_device options. Checks
+ // that all steps succeed and that the bind command's error matches want.
+ // Sets the socket_id to uniquely identify the socket bound if it is not
+ // nullptr.
+ void BindSocket(bool reuse, int device_id = 0, int want = 0,
+ int *socket_id = nullptr) {
+ next_socket_id_++;
+ sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket_fd = sockets_to_close_[next_socket_id_]->get();
+ if (socket_id != nullptr) {
+ *socket_id = next_socket_id_;
+ }
+
+ // If reuse is indicated, do that.
+ if (reuse) {
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ }
+
+ // If the device is non-zero, bind to that device.
+ if (device_id != 0) {
+ string device_name;
+ ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name));
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE,
+ device_name.c_str(), device_name.size() + 1),
+ SyscallSucceedsWithValue(0));
+ char get_device[100];
+ socklen_t get_device_size = 100;
+ EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device,
+ &get_device_size),
+ SyscallSucceedsWithValue(0));
+ }
+
+ struct sockaddr_in addr = {};
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr.sin_port = port_;
+ if (want == 0) {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ } else {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(want));
+ }
+
+ if (port_ == 0) {
+ // We don't yet know what port we'll be using so we need to fetch it and
+ // remember it for future commands.
+ socklen_t addr_size = sizeof(addr);
+ ASSERT_THAT(
+ getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr),
+ &addr_size),
+ SyscallSucceeds());
+ port_ = addr.sin_port;
+ }
+ }
+
+ private:
+ SocketKind socket_factory_;
+ // devices maps from the device id in the test case to the name of the device.
+ std::unordered_map<int, string> devices_;
+ // These are the tunnels that were created for the test and will be destroyed
+ // by the destructor.
+ vector<std::unique_ptr<Tunnel>> tunnels_;
+ // A list of all interface names before the test started.
+ std::unordered_set<string> interface_names_;
+ // The next ethernet device to use when requested a device.
+ int next_unused_eth_ = 1;
+ // The port for all tests. Originally 0 (any) and later set to the port that
+ // all further commands will use.
+ in_port_t port_ = 0;
+ // sockets_to_close_ is a map from action index to the socket that was
+ // created.
+ std::unordered_map<int,
+ std::unique_ptr<gvisor::testing::FileDescriptor>>
+ sockets_to_close_;
+ int next_socket_id_ = 0;
+};
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 1));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 2));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithReuse) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindAndRelease) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ int to_release;
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789));
+ // Release the bind to device 0 and try again.
+ ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 345));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc
new file mode 100644
index 000000000..f4ee775bd
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.cc
@@ -0,0 +1,75 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) {
+ int fd;
+ RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR));
+
+ // Using `new` to access a non-public constructor.
+ auto new_tunnel = absl::WrapUnique(new Tunnel(fd));
+
+ ifreq ifr = {};
+ ifr.ifr_flags = IFF_TUN;
+ strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name));
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr));
+ new_tunnel->name_ = ifr.ifr_name;
+ return new_tunnel;
+}
+
+std::unordered_set<string> GetInterfaceNames() {
+ struct if_nameindex* interfaces = if_nameindex();
+ std::unordered_set<string> names;
+ if (interfaces == nullptr) {
+ return names;
+ }
+ for (auto interface = interfaces;
+ interface->if_index != 0 || interface->if_name != nullptr; interface++) {
+ names.insert(interface->if_name);
+ }
+ if_freenameindex(interfaces);
+ return names;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h
new file mode 100644
index 000000000..f941ccc86
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.h
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+class Tunnel {
+ public:
+ static PosixErrorOr<std::unique_ptr<Tunnel>> New(
+ std::string tunnel_name = "");
+ const std::string& GetName() const { return name_; }
+
+ ~Tunnel() {
+ if (fd_ != -1) {
+ close(fd_);
+ }
+ }
+
+ private:
+ Tunnel(int fd) : fd_(fd) {}
+ int fd_ = -1;
+ std::string name_;
+};
+
+std::unordered_set<std::string> GetInterfaceNames();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index a43cf9bce..bfa7943b1 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -117,7 +117,7 @@ TEST_P(TCPSocketPairTest, RSTCausesPollHUP) {
struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0};
ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs),
SyscallSucceedsWithValue(1));
- ASSERT_NE(poll_fd.revents & (POLLHUP | POLLIN), 0);
+ ASSERT_NE(poll_fd3.revents & POLLHUP, 0);
}
// This test validates that even if a RST is sent the other end will not
diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc
index 3c716235b..eff7d577e 100644
--- a/test/syscalls/linux/socket_test_util.cc
+++ b/test/syscalls/linux/socket_test_util.cc
@@ -588,8 +588,9 @@ ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets,
return RetryEINTR(sendmsg)(sockets->first_fd(), &msg, 0);
}
-PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
- bool reuse_addr) {
+namespace internal {
+PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
+ SocketType type, bool reuse_addr) {
if (port < 0) {
return PosixError(EINVAL, "Invalid port");
}
@@ -664,10 +665,7 @@ PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
return available_port;
}
-
-PosixError FreeAvailablePort(int port) {
- return NoError();
-}
+} // namespace internal
PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) {
struct iovec iov;
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index ae0da2679..6efa8055f 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -492,6 +492,11 @@ uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
ssize_t payload_len);
+namespace internal {
+PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
+ SocketType type, bool reuse_addr);
+} // namespace internal
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_test_util_impl.cc b/test/syscalls/linux/socket_test_util_impl.cc
new file mode 100644
index 000000000..ef661a0e3
--- /dev/null
+++ b/test/syscalls/linux/socket_test_util_impl.cc
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
+ bool reuse_addr) {
+ return internal::TryPortAvailable(port, family, type, reuse_addr);
+}
+
+PosixError FreeAvailablePort(int port) { return NoError(); }
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index e25f264f6..85232cb1f 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -14,12 +14,16 @@
#include <fcntl.h>
#include <sys/eventfd.h>
+#include <sys/resource.h>
#include <sys/sendfile.h>
+#include <sys/time.h>
#include <unistd.h>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -36,23 +40,23 @@ TEST(SpliceTest, TwoRegularFiles) {
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
// Open the input file as read only.
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
// Open the output file as write only.
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
// Verify that it is rejected as expected; regardless of offsets.
loff_t in_offset = 0;
loff_t out_offset = 0;
- EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), &out_offset, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), &out_offset, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), nullptr, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), nullptr, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -75,8 +79,6 @@ TEST(SpliceTest, SamePipe) {
}
TEST(TeeTest, SamePipe) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create a new pipe.
int fds[2];
ASSERT_THAT(pipe(fds), SyscallSucceeds());
@@ -95,11 +97,9 @@ TEST(TeeTest, SamePipe) {
}
TEST(TeeTest, RegularFile) {
- SKIP_IF(IsRunningOnGvisor());
-
// Open some file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Create a new pipe.
@@ -109,9 +109,9 @@ TEST(TeeTest, RegularFile) {
const FileDescriptor wfd(fds[1]);
// Attempt to tee from the file.
- EXPECT_THAT(tee(inf.get(), wfd.get(), kPageSize, 0),
+ EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(tee(rfd.get(), inf.get(), kPageSize, 0),
+ EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -142,7 +142,7 @@ TEST(SpliceTest, FromEventFD) {
constexpr uint64_t kEventFDValue = 1;
int efd;
ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds());
- const FileDescriptor inf(efd);
+ const FileDescriptor in_fd(efd);
// Create a new pipe.
int fds[2];
@@ -152,7 +152,7 @@ TEST(SpliceTest, FromEventFD) {
// Splice 8-byte eventfd value to pipe.
constexpr int kEventFDSize = 8;
- EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
SyscallSucceedsWithValue(kEventFDSize));
// Contents should be equal.
@@ -166,7 +166,7 @@ TEST(SpliceTest, FromEventFD) {
TEST(SpliceTest, FromEventFDOffset) {
int efd;
ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor inf(efd);
+ const FileDescriptor in_fd(efd);
// Create a new pipe.
int fds[2];
@@ -179,7 +179,7 @@ TEST(SpliceTest, FromEventFDOffset) {
// This is not allowed because eventfd doesn't support pread.
constexpr int kEventFDSize = 8;
loff_t in_off = 0;
- EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -200,28 +200,29 @@ TEST(SpliceTest, ToEventFDOffset) {
int efd;
ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor outf(efd);
+ const FileDescriptor out_fd(efd);
// Attempt to splice 8-byte eventfd value to pipe with offset.
//
// This is not allowed because eventfd doesn't support pwrite.
loff_t out_off = 0;
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0),
- SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0),
+ SyscallFailsWithErrno(EINVAL));
}
TEST(SpliceTest, ToPipe) {
// Open the input file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Fill with some random data.
std::vector<char> buf(kPageSize);
RandomizeBuffer(buf.data(), buf.size());
- ASSERT_THAT(write(inf.get(), buf.data(), buf.size()),
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(kPageSize));
- ASSERT_THAT(lseek(inf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
// Create a new pipe.
int fds[2];
@@ -230,7 +231,7 @@ TEST(SpliceTest, ToPipe) {
const FileDescriptor wfd(fds[1]);
// Splice to the pipe.
- EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
SyscallSucceedsWithValue(kPageSize));
// Contents should be equal.
@@ -243,13 +244,13 @@ TEST(SpliceTest, ToPipe) {
TEST(SpliceTest, ToPipeOffset) {
// Open the input file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Fill with some random data.
std::vector<char> buf(kPageSize);
RandomizeBuffer(buf.data(), buf.size());
- ASSERT_THAT(write(inf.get(), buf.data(), buf.size()),
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(kPageSize));
// Create a new pipe.
@@ -261,7 +262,7 @@ TEST(SpliceTest, ToPipeOffset) {
// Splice to the pipe.
loff_t in_offset = kPageSize / 2;
EXPECT_THAT(
- splice(inf.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0),
+ splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0),
SyscallSucceedsWithValue(kPageSize / 2));
// Contents should be equal to only the second part.
@@ -286,22 +287,22 @@ TEST(SpliceTest, FromPipe) {
// Open the input file.
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
// Splice to the output file.
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), nullptr, kPageSize, 0),
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0),
SyscallSucceedsWithValue(kPageSize));
// The offset of the output should be equal to kPageSize. We assert that and
// reset to zero so that we can read the contents and ensure they match.
- EXPECT_THAT(lseek(outf.get(), 0, SEEK_CUR),
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR),
SyscallSucceedsWithValue(kPageSize));
- ASSERT_THAT(lseek(outf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
// Contents should be equal.
std::vector<char> rbuf(kPageSize);
- ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()),
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
SyscallSucceedsWithValue(kPageSize));
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
}
@@ -321,18 +322,19 @@ TEST(SpliceTest, FromPipeOffset) {
// Open the input file.
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
// Splice to the output file.
loff_t out_offset = kPageSize / 2;
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_offset, kPageSize, 0),
- SyscallSucceedsWithValue(kPageSize));
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
// Content should reflect the splice. We write to a specific offset in the
// file, so the internals should now be allocated sparsely.
std::vector<char> rbuf(kPageSize);
- ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()),
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
SyscallSucceedsWithValue(kPageSize));
std::vector<char> zbuf(kPageSize / 2);
memset(zbuf.data(), 0, zbuf.size());
@@ -404,8 +406,6 @@ TEST(SpliceTest, Blocking) {
}
TEST(TeeTest, Blocking) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create two new pipes.
int first[2], second[2];
ASSERT_THAT(pipe(first), SyscallSucceeds());
@@ -440,6 +440,49 @@ TEST(TeeTest, Blocking) {
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
}
+TEST(TeeTest, BlockingWrite) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> buf1(kPageSize);
+ RandomizeBuffer(buf1.data(), buf1.size());
+ ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Fill up the write pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf2(pipe_size);
+ ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ // Attempt a tee immediately; it should block.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Thread should be joinable.
+ t.Join();
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0);
+}
+
TEST(SpliceTest, NonBlocking) {
// Create two new pipes.
int first[2], second[2];
@@ -457,8 +500,6 @@ TEST(SpliceTest, NonBlocking) {
}
TEST(TeeTest, NonBlocking) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create two new pipes.
int first[2], second[2];
ASSERT_THAT(pipe(first), SyscallSucceeds());
@@ -473,6 +514,79 @@ TEST(TeeTest, NonBlocking) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST(TeeTest, MultiPage) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> wbuf(8 * kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Attempt a tee immediately; it should complete.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(wbuf.size());
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+ ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+}
+
+TEST(SpliceTest, FromPipeMaxFileSize) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the input file.
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor out_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
+
+ EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds());
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END),
+ SyscallSucceedsWithValue(13 << 20));
+
+ // Set our file size limit.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGXFSZ);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+ rlimit rlim = {};
+ rlim.rlim_cur = rlim.rlim_max = (13 << 20);
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds());
+
+ // Splice to the output file.
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0),
+ SyscallFailsWithErrno(EFBIG));
+
+ // Contents should be equal.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc
index 59fb5dfe6..7e73325bf 100644
--- a/test/syscalls/linux/sticky.cc
+++ b/test/syscalls/linux/sticky.cc
@@ -17,9 +17,11 @@
#include <sys/prctl.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <vector>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
@@ -27,8 +29,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "first scratch UID");
-DEFINE_int32(scratch_gid, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID");
namespace gvisor {
namespace testing {
@@ -52,10 +54,12 @@ TEST(StickyTest, StickyBitPermDenied) {
}
// Change EUID and EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM));
});
@@ -78,8 +82,9 @@ TEST(StickyTest, StickyBitSameUID) {
}
// Change EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
// We still have the same EUID.
EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
@@ -101,10 +106,12 @@ TEST(StickyTest, StickyBitCapFOWNER) {
EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
// Change EUID and EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true));
EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc
index 8f4d3f386..bfa031bce 100644
--- a/test/syscalls/linux/tcp_socket.cc
+++ b/test/syscalls/linux/tcp_socket.cc
@@ -579,7 +579,7 @@ TEST_P(TcpSocketTest, TcpInq) {
if (size == sizeof(buf)) {
break;
}
- usleep(10000);
+ absl::SleepFor(absl::Milliseconds(10));
}
struct msghdr msg = {};
@@ -610,6 +610,25 @@ TEST_P(TcpSocketTest, TcpInq) {
}
}
+TEST_P(TcpSocketTest, Tiocinq) {
+ char buf[1024];
+ size_t size = sizeof(buf);
+ ASSERT_THAT(RetryEINTR(write)(s_, buf, size), SyscallSucceedsWithValue(size));
+
+ uint32_t seed = time(nullptr);
+ const size_t max_chunk = size / 10;
+ while (size > 0) {
+ size_t chunk = (rand_r(&seed) % max_chunk) + 1;
+ ssize_t read = RetryEINTR(recvfrom)(t_, buf, chunk, 0, nullptr, nullptr);
+ ASSERT_THAT(read, SyscallSucceeds());
+ size -= read;
+
+ int inq = 0;
+ ASSERT_THAT(ioctl(t_, TIOCINQ, &inq), SyscallSucceeds());
+ ASSERT_EQ(inq, size);
+ }
+}
+
TEST_P(TcpSocketTest, TcpSCMPriority) {
char buf[1024];
ASSERT_THAT(RetryEINTR(write)(s_, buf, sizeof(buf)),
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index fd42e81e1..3db18d7ac 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -23,6 +23,7 @@
#include <atomic>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/cleanup.h"
@@ -33,8 +34,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(timers_test_sleep, false,
- "If true, sleep forever instead of running tests.");
+ABSL_FLAG(bool, timers_test_sleep, false,
+ "If true, sleep forever instead of running tests.");
using ::testing::_;
using ::testing::AnyOf;
@@ -635,7 +636,7 @@ TEST(IntervalTimerTest, IgnoredSignalCountsAsOverrun) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_timers_test_sleep) {
+ if (absl::GetFlag(FLAGS_timers_test_sleep)) {
while (true) {
absl::SleepFor(absl::Seconds(10));
}
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
index bf1ca8679..6218fbce1 100644
--- a/test/syscalls/linux/uidgid.cc
+++ b/test/syscalls/linux/uidgid.cc
@@ -18,17 +18,19 @@
#include <unistd.h>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "test/util/capability_util.h"
#include "test/util/posix_error.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+#include "test/util/uid_util.h"
-DEFINE_int32(scratch_uid1, 65534, "first scratch UID");
-DEFINE_int32(scratch_uid2, 65533, "second scratch UID");
-DEFINE_int32(scratch_gid1, 65534, "first scratch GID");
-DEFINE_int32(scratch_gid2, 65533, "second scratch GID");
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID");
using ::testing::UnorderedElementsAreArray;
@@ -67,30 +69,6 @@ TEST(UidGidTest, Getgroups) {
// here; see the setgroups test below.
}
-// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns
-// true. Otherwise IsRoot logs an explanatory message and returns false.
-PosixErrorOr<bool> IsRoot() {
- uid_t ruid, euid, suid;
- int rc = getresuid(&ruid, &euid, &suid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresuid");
- }
- if (ruid != 0 || euid != 0 || suid != 0) {
- return false;
- }
- gid_t rgid, egid, sgid;
- rc = getresgid(&rgid, &egid, &sgid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresgid");
- }
- if (rgid != 0 || egid != 0 || sgid != 0) {
- return false;
- }
- return true;
-}
-
// Checks that the calling process' real/effective/saved user IDs are
// ruid/euid/suid respectively.
PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) {
@@ -146,7 +124,7 @@ TEST(UidGidRootTest, Setuid) {
// real UID.
EXPECT_THAT(syscall(SYS_setuid, -1), SyscallFailsWithErrno(EINVAL));
- const uid_t uid = FLAGS_scratch_uid1;
+ const uid_t uid = absl::GetFlag(FLAGS_scratch_uid1);
EXPECT_THAT(syscall(SYS_setuid, uid), SyscallSucceeds());
// "If the effective UID of the caller is root (more precisely: if the
// caller has the CAP_SETUID capability), the real UID and saved set-user-ID
@@ -160,7 +138,7 @@ TEST(UidGidRootTest, Setgid) {
EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL));
- const gid_t gid = FLAGS_scratch_gid1;
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
ASSERT_THAT(setgid(gid), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid));
}
@@ -168,7 +146,7 @@ TEST(UidGidRootTest, Setgid) {
TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
- const gid_t gid = FLAGS_scratch_gid1;
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
// NOTE(b/64676707): Do setgid in a separate thread so that we can test if
// info.si_pid is set correctly.
ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); });
@@ -189,8 +167,8 @@ TEST(UidGidRootTest, Setreuid) {
// cannot be opened by the `uid` set below after the test. After calling
// setuid(non-zero-UID), there is no way to get root privileges back.
ScopedThread([&] {
- const uid_t ruid = FLAGS_scratch_uid1;
- const uid_t euid = FLAGS_scratch_uid2;
+ const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1);
+ const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2);
// Use syscall instead of glibc setuid wrapper because we want this setuid
// call to only apply to this task. posix threads, however, require that all
@@ -211,8 +189,8 @@ TEST(UidGidRootTest, Setregid) {
EXPECT_THAT(setregid(-1, -1), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0));
- const gid_t rgid = FLAGS_scratch_gid1;
- const gid_t egid = FLAGS_scratch_gid2;
+ const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1);
+ const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2);
ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid));
}
diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc
index 0a5d91017..d8824b171 100644
--- a/test/syscalls/linux/uname.cc
+++ b/test/syscalls/linux/uname.cc
@@ -41,6 +41,19 @@ TEST(UnameTest, Sanity) {
TEST(UnameTest, SetNames) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ char hostname[65];
+ ASSERT_THAT(sethostname("0123456789", 3), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "012");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 11), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 12), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
constexpr char kHostname[] = "wubbalubba";
ASSERT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds());
@@ -54,7 +67,6 @@ TEST(UnameTest, SetNames) {
EXPECT_EQ(absl::string_view(buf.domainname), kDomainname);
// These should just be glibc wrappers that also call uname(2).
- char hostname[65];
EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
EXPECT_EQ(absl::string_view(hostname), kHostname);
diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc
index b6f65e027..2040375c9 100644
--- a/test/syscalls/linux/unlink.cc
+++ b/test/syscalls/linux/unlink.cc
@@ -123,6 +123,8 @@ TEST(UnlinkTest, AtBad) {
SyscallSucceeds());
EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", AT_REMOVEDIR),
SyscallFailsWithErrno(ENOTDIR));
+ EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile/", 0),
+ SyscallFailsWithErrno(ENOTDIR));
ASSERT_THAT(close(fd), SyscallSucceeds());
EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds());
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
index f67b06f37..0aaba482d 100644
--- a/test/syscalls/linux/vfork.cc
+++ b/test/syscalls/linux/vfork.cc
@@ -22,14 +22,15 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/time.h"
#include "test/util/logging.h"
#include "test/util/multiprocess_util.h"
#include "test/util/test_util.h"
#include "test/util/time_util.h"
-DEFINE_bool(vfork_test_child, false,
- "If true, run the VforkTest child workload.");
+ABSL_FLAG(bool, vfork_test_child, false,
+ "If true, run the VforkTest child workload.");
namespace gvisor {
namespace testing {
@@ -186,7 +187,7 @@ int RunChild() {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_vfork_test_child) {
+ if (absl::GetFlag(FLAGS_vfork_test_child)) {
return gvisor::testing::RunChild();
}
diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go
index 32408f021..c1e9ce22c 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/syscalls/syscall_test_runner.go
@@ -20,12 +20,10 @@ import (
"flag"
"fmt"
"io/ioutil"
- "math"
"os"
"os/exec"
"os/signal"
"path/filepath"
- "strconv"
"strings"
"syscall"
"testing"
@@ -35,7 +33,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/specutils"
- "gvisor.dev/gvisor/runsc/test/testutil"
+ "gvisor.dev/gvisor/runsc/testutil"
"gvisor.dev/gvisor/test/syscalls/gtest"
)
@@ -358,32 +356,14 @@ func main() {
fatalf("ParseTestCases(%q) failed: %v", testBin, err)
}
- // If sharding, then get the subset of tests to run based on the shard index.
- if indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS"); indexStr != "" && totalStr != "" {
- // Parse index and total to ints.
- index, err := strconv.Atoi(indexStr)
- if err != nil {
- fatalf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
- }
- total, err := strconv.Atoi(totalStr)
- if err != nil {
- fatalf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
- }
- // Calculate subslice of tests to run.
- shardSize := int(math.Ceil(float64(len(testCases)) / float64(total)))
- begin := index * shardSize
- // Set end as begin of next subslice.
- end := ((index + 1) * shardSize)
- if begin > len(testCases) {
- // Nothing to run.
- return
- }
- if end > len(testCases) {
- end = len(testCases)
- }
- testCases = testCases[begin:end]
+ // Get subset of tests corresponding to shard.
+ begin, end, err := testutil.TestBoundsForShard(len(testCases))
+ if err != nil {
+ fatalf("TestsForShard() failed: %v", err)
}
+ testCases = testCases[begin:end]
+ // Run the tests.
var tests []testing.InternalTest
for _, tc := range testCases {
// Capture tc.
diff --git a/test/util/BUILD b/test/util/BUILD
index c124cef34..5d2a9cc2c 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -1,3 +1,6 @@
+load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
+load("//test/syscalls:build_defs.bzl", "select_for_linux")
+
package(
default_visibility = ["//:sandbox"],
licenses = ["notice"],
@@ -139,7 +142,11 @@ cc_library(
cc_library(
name = "save_util",
testonly = 1,
- srcs = ["save_util.cc"],
+ srcs = ["save_util.cc"] +
+ select_for_linux(
+ ["save_util_linux.cc"],
+ ["save_util_other.cc"],
+ ),
hdrs = ["save_util.h"],
)
@@ -232,8 +239,9 @@ cc_library(
":logging",
":posix_error",
":save_util",
- "@com_github_gflags_gflags//:gflags",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
@@ -316,3 +324,14 @@ cc_library(
":test_util",
],
)
+
+cc_library(
+ name = "uid_util",
+ testonly = 1,
+ srcs = ["uid_util.cc"],
+ hdrs = ["uid_util.h"],
+ deps = [
+ ":posix_error",
+ ":save_util",
+ ],
+)
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index ae49725a0..f7d231b14 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -105,6 +105,15 @@ PosixErrorOr<struct stat> Stat(absl::string_view path) {
return stat_buf;
}
+PosixErrorOr<struct stat> Fstat(int fd) {
+ struct stat stat_buf;
+ int res = fstat(fd, &stat_buf);
+ if (res < 0) {
+ return PosixError(errno, absl::StrCat("fstat ", fd));
+ }
+ return stat_buf;
+}
+
PosixErrorOr<bool> Exists(absl::string_view path) {
struct stat stat_buf;
int res = stat(std::string(path).c_str(), &stat_buf);
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index 3969f8309..e5b555891 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -35,6 +35,9 @@ PosixErrorOr<bool> Exists(absl::string_view path);
// Returns a stat structure for the given path or an error.
PosixErrorOr<struct stat> Stat(absl::string_view path);
+// Returns a stat struct for the given fd.
+PosixErrorOr<struct stat> Fstat(int fd);
+
// Deletes the file or directory at path or returns an error.
PosixError Delete(absl::string_view path);
diff --git a/test/util/memory_util.h b/test/util/memory_util.h
index 190c469b5..e189b73e8 100644
--- a/test/util/memory_util.h
+++ b/test/util/memory_util.h
@@ -118,6 +118,18 @@ inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) {
return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0);
}
+// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of
+// void* isn't directly compatible with SyscallSucceeds.
+inline PosixErrorOr<void*> Mremap(void* old_address, size_t old_size,
+ size_t new_size, int flags,
+ void* new_address) {
+ void* rv = mremap(old_address, old_size, new_size, flags, new_address);
+ if (rv == MAP_FAILED) {
+ return PosixError(errno, "mremap failed");
+ }
+ return rv;
+}
+
// Returns true if the page containing addr is mapped.
inline bool IsMapped(uintptr_t addr) {
int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)),
diff --git a/test/util/proc_util.cc b/test/util/proc_util.cc
index 75b24da37..34d636ba9 100644
--- a/test/util/proc_util.cc
+++ b/test/util/proc_util.cc
@@ -88,7 +88,7 @@ PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps(
std::vector<ProcMapsEntry> entries;
auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
for (const auto& l : lines) {
- std::cout << "line: " << l;
+ std::cout << "line: " << l << std::endl;
ASSIGN_OR_RETURN_ERRNO(auto entry, ParseProcMapsLine(l));
entries.push_back(entry);
}
diff --git a/test/util/save_util.cc b/test/util/save_util.cc
index 05f52b80d..384d626f0 100644
--- a/test/util/save_util.cc
+++ b/test/util/save_util.cc
@@ -16,8 +16,8 @@
#include <stddef.h>
#include <stdlib.h>
-#include <sys/syscall.h>
#include <unistd.h>
+
#include <atomic>
#include <cerrno>
@@ -61,13 +61,11 @@ void DisableSave::reset() {
}
}
-void MaybeSave() {
- if (CooperativeSaveEnabled() && !save_disable.load()) {
- int orig_errno = errno;
- syscall(SYS_create_module, nullptr, 0);
- errno = orig_errno;
- }
+namespace internal {
+bool ShouldSave() {
+ return CooperativeSaveEnabled() && (save_disable.load() == 0);
}
+} // namespace internal
} // namespace testing
} // namespace gvisor
diff --git a/test/util/save_util.h b/test/util/save_util.h
index 90460701e..bddad6120 100644
--- a/test/util/save_util.h
+++ b/test/util/save_util.h
@@ -41,6 +41,11 @@ class DisableSave {
//
// errno is guaranteed to be preserved.
void MaybeSave();
+
+namespace internal {
+bool ShouldSave();
+} // namespace internal
+
} // namespace testing
} // namespace gvisor
diff --git a/test/util/save_util_linux.cc b/test/util/save_util_linux.cc
new file mode 100644
index 000000000..7a0f14342
--- /dev/null
+++ b/test/util/save_util_linux.cc
@@ -0,0 +1,33 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+void MaybeSave() {
+ if (internal::ShouldSave()) {
+ int orig_errno = errno;
+ syscall(SYS_create_module, nullptr, 0);
+ errno = orig_errno;
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/runsc/test/testutil/testutil_race.go b/test/util/save_util_other.cc
index 86db6ffa1..1aca663b7 100644
--- a/runsc/test/testutil/testutil_race.go
+++ b/test/util/save_util_other.cc
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build race
+namespace gvisor {
+namespace testing {
-package testutil
-
-func init() {
- RaceEnabled = true
+void MaybeSave() {
+ // Saving is never available in a non-linux environment.
}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/test_util.cc b/test/util/test_util.cc
index e42bba04a..ba0dcf7d0 100644
--- a/test/util/test_util.cc
+++ b/test/util/test_util.cc
@@ -28,6 +28,8 @@
#include <vector>
#include "absl/base/attributes.h"
+#include "absl/flags/flag.h" // IWYU pragma: keep
+#include "absl/flags/parse.h" // IWYU pragma: keep
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
@@ -224,7 +226,7 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) {
void TestInit(int* argc, char*** argv) {
::testing::InitGoogleTest(argc, *argv);
- ::gflags::ParseCommandLineFlags(argc, argv, true);
+ ::absl::ParseCommandLine(*argc, *argv);
// Always mask SIGPIPE as it's common and tests aren't expected to handle it.
struct sigaction sa = {};
diff --git a/test/util/test_util.h b/test/util/test_util.h
index cdbe8bfd1..b9d2dc2ba 100644
--- a/test/util/test_util.h
+++ b/test/util/test_util.h
@@ -185,7 +185,6 @@
#include <utility>
#include <vector>
-#include <gflags/gflags.h>
#include "gmock/gmock.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
diff --git a/test/util/thread_util.h b/test/util/thread_util.h
index 860e77531..923c4fe10 100644
--- a/test/util/thread_util.h
+++ b/test/util/thread_util.h
@@ -16,7 +16,9 @@
#define GVISOR_TEST_UTIL_THREAD_UTIL_H_
#include <pthread.h>
+#ifdef __linux__
#include <sys/syscall.h>
+#endif
#include <unistd.h>
#include <functional>
@@ -66,13 +68,13 @@ class ScopedThread {
private:
void CreateThread() {
- TEST_PCHECK_MSG(
- pthread_create(&pt_, /* attr = */ nullptr,
- +[](void* arg) -> void* {
- return static_cast<ScopedThread*>(arg)->f_();
- },
- this) == 0,
- "thread creation failed");
+ TEST_PCHECK_MSG(pthread_create(
+ &pt_, /* attr = */ nullptr,
+ +[](void* arg) -> void* {
+ return static_cast<ScopedThread*>(arg)->f_();
+ },
+ this) == 0,
+ "thread creation failed");
}
std::function<void*()> f_;
@@ -81,7 +83,9 @@ class ScopedThread {
void* retval_ = nullptr;
};
+#ifdef __linux__
inline pid_t gettid() { return syscall(SYS_gettid); }
+#endif
} // namespace testing
} // namespace gvisor
diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc
new file mode 100644
index 000000000..b131b4b99
--- /dev/null
+++ b/test/util/uid_util.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<bool> IsRoot() {
+ uid_t ruid, euid, suid;
+ int rc = getresuid(&ruid, &euid, &suid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresuid");
+ }
+ if (ruid != 0 || euid != 0 || suid != 0) {
+ return false;
+ }
+ gid_t rgid, egid, sgid;
+ rc = getresgid(&rgid, &egid, &sgid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresgid");
+ }
+ if (rgid != 0 || egid != 0 || sgid != 0) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/uid_util.h b/test/util/uid_util.h
new file mode 100644
index 000000000..2cd387fb0
--- /dev/null
+++ b/test/util/uid_util.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns true if the caller's real/effective/saved user/group IDs are all 0.
+PosixErrorOr<bool> IsRoot();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_
diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/third_party/gvsync/downgradable_rwmutex_unsafe.go
index 069939033..1f6007aa1 100644
--- a/third_party/gvsync/downgradable_rwmutex_unsafe.go
+++ b/third_party/gvsync/downgradable_rwmutex_unsafe.go
@@ -57,9 +57,6 @@ func (rw *DowngradableRWMutex) RLock() {
// RUnlock undoes a single RLock call.
func (rw *DowngradableRWMutex) RUnlock() {
if RaceEnabled {
- // TODO(jamieliu): Why does this need to be ReleaseMerge instead of
- // Release? IIUC this establishes Unlock happens-before RUnlock, which
- // seems unnecessary.
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
RaceDisable()
}
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index d9e79401d..ddb9b6e7b 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -59,7 +59,11 @@ git checkout -b go "${go_branch}"
# Start working on a merge commit that combines the previous history with the
# current history. Note that we don't actually want any changes yet.
-git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
+#
+# N.B. The git behavior changed at some point and the relevant flag was added
+# to allow for override, so try the only behavior first then pass the flag.
+git merge --no-commit --strategy ours ${head} || \
+ git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
# Sync the entire gopath_dir and go.mod.
rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" .
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
new file mode 100644
index 000000000..c862b277c
--- /dev/null
+++ b/tools/go_marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "go_marshal",
+ srcs = ["main.go"],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//tools/go_marshal/gomarshal",
+ ],
+)
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
new file mode 100644
index 000000000..481575bd3
--- /dev/null
+++ b/tools/go_marshal/README.md
@@ -0,0 +1,164 @@
+This package implements the go_marshal utility.
+
+# Overview
+
+`go_marshal` is a code generation utility similar to `go_stateify` for
+automatically generating code to marshal go data structures to memory.
+
+`go_marshal` attempts to improve on `binary.Write` and the sentry's
+`binary.Marshal` by moving the go runtime reflection necessary to marshal a
+struct to compile-time.
+
+`go_marshal` automatically generates implementations for `abi.Marshallable` and
+`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall
+implementations) can directly invoke `safemem.Reader.ReadToBlocks` and
+`safemem.Writer.WriteFromBlocks`. Data structures that require custom
+serialization will have manual implementations for these interfaces.
+
+Data structures can be flagged for code generation by adding a struct-level
+comment `// +marshal`.
+
+# Usage
+
+See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`.
+
+The recommended way to generate a go library with marshalling is to use the
+`go_library` with mostly identical configuration as the native go_library rule.
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+```
+
+Under the hood, the `go_marshal` rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (note that the above is the preferred method):
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ deps = [
+ "<PKGPATH>/gvisor/pkg/abi",
+ "<PKGPATH>/gvisor/pkg/sentry/safemem/safemem",
+ "<PKGPATH>/gvisor/pkg/sentry/usermem/usermem",
+ ],
+)
+```
+
+As part of the interface generation, `go_marshal` also generates some tests for
+sanity checking the struct definitions for potential alignment issues, and a
+simple round-trip test through Marshal/Unmarshal to verify the implementation.
+These tests use reflection to verify properties of the ABI struct, and should be
+considered part of the generated interfaces (but are too expensive to execute at
+runtime). Ensure these tests run at some point.
+
+```
+$ cat BUILD
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+$ blaze build :foo
+$ blaze query ...
+<path-to-dir>:foo_abi_autogen
+<path-to-dir>:foo_abi_autogen_test
+$ blaze test :foo_abi_autogen_test
+<test-output>
+```
+
+# Restrictions
+
+Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is
+intended for ABI structs, which have these additional restrictions:
+
+- At the moment, `go_marshal` only supports struct declarations.
+
+- Structs are marshalled as packed types. This means no implicit padding is
+ inserted between fields shorter than the platform register size. For
+ alignment, manually insert padding fields.
+
+- Structs used with `go_marshal` must have a compile-time static size. This
+ means no dynamically sizes fields like slices or strings. Use statically
+ sized array (byte arrays for strings) instead.
+
+- No pointers, channel, map or function pointer fields, and no fields that are
+ arrays of these types. These don't make sense in an ABI data structure.
+
+- We could support opaque pointers as `uintptr`, but this is currently not
+ implemented. Implementing this would require handling the architecture
+ dependent native pointer size.
+
+- Fields must either be a primitive integer type (`byte`,
+ `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable.
+
+- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric
+ type.
+
+- `float*` fields are currently not supported, but could be if necessary.
+
+# Appendix
+
+## Working with Non-Packed Structs
+
+ABI structs must generally be packed types, meaning they should have no implicit
+padding between short fields. However, if a field is tagged
+`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower
+mechanism to deal with potentially unaligned fields.
+
+Note that the non-packed property is inheritted by any other struct that embeds
+this struct, since the `go_marshal` tool currently can't reason about alignments
+for embedded structs that are not aligned.
+
+Because of this, it's generally best to avoid using `marshal:"unaligned"` and
+insert explicit padding fields instead.
+
+## Debugging go_marshal
+
+To enable debugging output from the go marshal tool, pass the `-debug` flag to
+the tool. When using the build rules from above, add a `debug = True` field to
+the build rule like this:
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+ debug = True,
+)
+```
+
+## Modifying the `go_marshal` Tool
+
+The following are some guidelines for modifying the `go_marshal` tool:
+
+- The `go_marshal` tool currently does a single pass over all types requesting
+ code generation, in arbitrary order. This means the generated code can't
+ directly obtain information about embedded marshallable types at
+ compile-time. One way to work around this restriction is to add a new
+ Marshallable interface method providing this piece of information, and
+ calling it from the generated code. Use this sparingly, as we want to rely
+ on compile-time information as much as possible for performance.
+
+- No runtime reflection in the code generated for the marshallable interface.
+ The entire point of the tool is to avoid runtime reflection. The generated
+ tests may use reflection.
diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD
new file mode 100644
index 000000000..c859ced77
--- /dev/null
+++ b/tools/go_marshal/analysis/BUILD
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "analysis",
+ testonly = 1,
+ srcs = ["analysis_unsafe.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
new file mode 100644
index 000000000..9a9a4f298
--- /dev/null
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -0,0 +1,175 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package analysis implements common functionality used by generated
+// go_marshal tests.
+package analysis
+
+// All functions in this package are unsafe and are not intended for general
+// consumption. They contain sharp edge cases and the caller is responsible for
+// ensuring none of them are hit. Callers must be carefully to pass in only sane
+// arguments. Failure to do so may cause panics at best and arbitrary memory
+// corruption at worst.
+//
+// Never use outside of tests.
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+// RandomizeValue assigns random value(s) to an abitrary type. This is intended
+// for used with ABI structs from go_marshal, meaning the typical restrictions
+// apply (fixed-size types, no pointers, maps, channels, etc), and should only
+// be used on zeroed values to avoid overwriting pointers to active go objects.
+//
+// Internally, we populate the type with random data by doing an unsafe cast to
+// access the underlying memory of the type and filling it as if it were a byte
+// slice. This almost gets us what we want, but padding fields named "_" are
+// normally not accessible, so we walk the type and recursively zero all "_"
+// fields.
+//
+// Precondition: x must be a pointer. x must not contain any valid
+// pointers to active go objects (pointer fields aren't allowed in ABI
+// structs anyways), or we'd be violating the go runtime contract and
+// the GC may malfunction.
+func RandomizeValue(x interface{}) {
+ v := reflect.Indirect(reflect.ValueOf(x))
+ if !v.CanSet() {
+ panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument")
+ }
+
+ // Cast the underlying memory for the type into a byte slice.
+ var b []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer.
+ hdr.Data = v.UnsafeAddr()
+ hdr.Len = int(v.Type().Size())
+ hdr.Cap = hdr.Len
+
+ // Fill the byte slice with random data, which in effect fills the type with
+ // random values.
+ n, err := rand.Read(b)
+ if err != nil || n != len(b) {
+ panic("unreachable")
+ }
+
+ // Normally, padding fields are not accessible, so zero them out.
+ reflectZeroPaddingFields(v.Type(), b, false)
+}
+
+// reflectZeroPaddingFields assigns zero values to padding fields for the value
+// of type r, represented by the memory in data. Padding fields are defined as
+// fields with the name "_". If zero is true, the immediate value itself is
+// zeroed. In addition, the type is recursively scanned for padding fields in
+// inner types.
+//
+// This is used for zeroing padding fields after calling RandomizeValue.
+func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) {
+ if zero {
+ for i, _ := range data {
+ data[i] = 0
+ }
+ }
+ switch r.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // These types are explicitly allowed in an ABI type, but we don't need
+ // to recurse further as they're scalar types.
+ case reflect.Struct:
+ for i, numFields := 0, r.NumField(); i < numFields; i++ {
+ f := r.Field(i)
+ off := f.Offset
+ len := f.Type.Size()
+ window := data[off : off+len]
+ reflectZeroPaddingFields(f.Type, window, f.Name == "_")
+ }
+ case reflect.Array:
+ eLen := int(r.Elem().Size())
+ if int(r.Size()) != eLen*r.Len() {
+ panic("Array has unexpected size?")
+ }
+ for i, n := 0, r.Len(); i < n; i++ {
+ reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false)
+ }
+ default:
+ panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind()))
+
+ }
+}
+
+// AlignmentCheck ensures the definition of the type represented by typ doesn't
+// cause the go compiler to emit implicit padding between elements of the type
+// (i.e. fields in a struct).
+//
+// AlignmentCheck doesn't explicitly recurse for embedded structs because any
+// struct present in an ABI struct must also be Marshallable, and therefore
+// they're aligned by definition (or their alignment check would have failed).
+func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
+ switch typ.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // Primitive types are always considered well aligned. Primitive types
+ // that are fields in structs are checked independently, this branch
+ // exists to handle recursive calls to alignmentCheck.
+ case reflect.Struct:
+ xOff := 0
+ nextXOff := 0
+ skipNext := false
+ for i, numFields := 0, typ.NumField(); i < numFields; i++ {
+ xOff = nextXOff
+ f := typ.Field(i)
+ fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size())
+ nextXOff = int(f.Offset + f.Type.Size())
+
+ if f.Name == "_" {
+ // Padding fields need not be aligned.
+ fmt.Printf("Padding field of type %v\n", f.Type)
+ continue
+ }
+
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ skipNext = true
+ continue
+ }
+
+ if skipNext {
+ skipNext = false
+ fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name)
+ continue
+ }
+
+ if xOff != int(f.Offset) {
+ implicitPad := int(f.Offset) - xOff
+ t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad)
+ }
+ }
+
+ // Ensure structs end on a byte explicitly defined by the type.
+ if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
+ implicitPad := int(typ.Size()) - nextXOff
+ f := typ.Field(typ.NumField() - 1) // Final field
+ t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
+ typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
+ }
+ case reflect.Array:
+ // Independent arrays are also always considered well aligned. We only
+ // need to worry about their alignment when they're embedded in structs,
+ // which we handle above.
+ default:
+ t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind())
+ }
+ return true, uint64(typ.Size())
+}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
new file mode 100644
index 000000000..c32eb559f
--- /dev/null
+++ b/tools/go_marshal/defs.bzl
@@ -0,0 +1,152 @@
+"""Marshal is a tool for generating marshalling interfaces for Go types.
+
+The recommended way is to use the go_library rule defined below with mostly
+identical configuration as the native go_library rule.
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+
+Under the hood, the go_marshal rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (the above is still the preferred way):
+
+load("//tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ deps = [
+ "//tools/go_marshal:marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ],
+)
+"""
+
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+
+def _go_marshal_impl(ctx):
+ """Execute the go_marshal tool."""
+ output = ctx.outputs.lib
+ output_test = ctx.outputs.test
+ (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD")
+
+ decl = "/".join(["gvisor.dev/gvisor", build_dir])
+
+ # Run the marshal command.
+ args = ["-output=%s" % output.path]
+ args += ["-pkg=%s" % ctx.attr.package]
+ args += ["-output_test=%s" % output_test.path]
+ args += ["-declarationPkg=%s" % decl]
+
+ if ctx.attr.debug:
+ args += ["-debug"]
+
+ args += ["--"]
+ for src in ctx.attr.srcs:
+ args += [f.path for f in src.files.to_list()]
+ ctx.actions.run(
+ inputs = ctx.files.srcs,
+ outputs = [output, output_test],
+ mnemonic = "GoMarshal",
+ progress_message = "go_marshal: %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+# Generates save and restore logic from a set of Go files.
+#
+# Args:
+# name: the name of the rule.
+# srcs: the input source files. These files should include all structs in the
+# package that need to be saved.
+# imports: an optional list of extra, non-aliased, Go-style absolute import
+# paths.
+# out: the name of the generated file output. This must not conflict with any
+# other files and must be added to the srcs of the relevant go_library.
+# package: the package name for the input sources.
+go_marshal = rule(
+ implementation = _go_marshal_impl,
+ attrs = {
+ "srcs": attr.label_list(mandatory = True, allow_files = True),
+ "libname": attr.string(mandatory = True),
+ "imports": attr.string_list(mandatory = False),
+ "package": attr.string(mandatory = True),
+ "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")),
+ },
+ outputs = {
+ "lib": "%{name}_unsafe.go",
+ "test": "%{name}_test.go",
+ },
+)
+
+def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs):
+ """wraps the standard go_library and does mashalling interface generation.
+
+ Args:
+ name: Same as native go_library.
+ srcs: Same as native go_library.
+ deps: Same as native go_library.
+ imports: Extra import paths to pass to the go_marshal tool.
+ debug: Enables debugging output from the go_marshal tool.
+ **kwargs: Remaining args to pass to the native go_library rule unmodified.
+ """
+ go_marshal(
+ name = name + "_abi_autogen",
+ libname = name,
+ srcs = [src for src in srcs if src.endswith(".go")],
+ debug = debug,
+ imports = imports,
+ package = name,
+ )
+
+ extra_deps = [
+ "//tools/go_marshal/marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ]
+
+ all_srcs = srcs + [name + "_abi_autogen_unsafe.go"]
+ all_deps = deps + [] # + extra_deps
+
+ for extra in extra_deps:
+ if extra not in deps:
+ all_deps.append(extra)
+
+ _go_library(
+ name = name,
+ srcs = all_srcs,
+ deps = all_deps,
+ **kwargs
+ )
+
+ # Don't pass importpath arg to go_test.
+ kwargs.pop("importpath", "")
+
+ _go_test(
+ name = name + "_abi_autogen_test",
+ srcs = [name + "_abi_autogen_test.go"],
+ # Generated test has a fixed set of dependencies since we generate these
+ # tests. They should only depend on the library generated above, and the
+ # Marshallable interface.
+ deps = [
+ ":" + name,
+ "//tools/go_marshal/analysis",
+ ],
+ **kwargs
+ )
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
new file mode 100644
index 000000000..a0eae6492
--- /dev/null
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gomarshal",
+ srcs = [
+ "generator.go",
+ "generator_interfaces.go",
+ "generator_tests.go",
+ "util.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
new file mode 100644
index 000000000..641ccd938
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -0,0 +1,382 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gomarshal implements the go_marshal code generator. See README.md.
+package gomarshal
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "sort"
+)
+
+const (
+ marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem"
+ safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+)
+
+// List of identifiers we use in generated code, that may conflict a
+// similarly-named source identifier. Avoid problems by refusing the generate
+// code when we see these.
+//
+// This only applies to import aliases at the moment. All other identifiers
+// are qualified by a receiver argument, since they're struct fields.
+//
+// All recievers are single letters, so we don't allow import aliases to be a
+// single letter.
+var badIdents = []string{
+ "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ // All single-letter identifiers.
+}
+
+// Generator drives code generation for a single invocation of the go_marshal
+// utility.
+//
+// The Generator holds arguments passed to the tool, and drives parsing,
+// processing and code Generator for all types marked with +marshal declared in
+// the input files.
+//
+// See Generator.run() as the entry point.
+type Generator struct {
+ // Paths to input go source files.
+ inputs []string
+ // Output file to write generated go source.
+ output *os.File
+ // Output file to write generated tests.
+ outputTest *os.File
+ // Package name for the generated file.
+ pkg string
+ // Go import path for package we're processing. This package should directly
+ // declare the type we're generating code for.
+ declaration string
+ // Set of extra packages to import in the generated file.
+ imports *importTable
+}
+
+// NewGenerator creates a new code Generator.
+func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) {
+ f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
+ }
+ fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err)
+ }
+ g := Generator{
+ inputs: srcs,
+ output: f,
+ outputTest: fTest,
+ pkg: pkg,
+ declaration: declaration,
+ imports: newImportTable(),
+ }
+ for _, i := range imports {
+ // All imports on the extra imports list are unconditionally marked as
+ // used, so they're always added to the generated code.
+ g.imports.add(i).markUsed()
+ }
+ g.imports.add(marshalImport).markUsed()
+ // The follow imports may or may not be used by the generated
+ // code, depending what's required for the target types. Don't
+ // mark these imports as used by default.
+ g.imports.add(usermemImport)
+ g.imports.add(safecopyImport)
+ g.imports.add("unsafe")
+
+ return &g, nil
+}
+
+// writeHeader writes the header for the generated source file. The header
+// includes the package name, package level comments and import statements.
+func (g *Generator) writeHeader() error {
+ var b sourceBuffer
+ b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+ b.emit("package %s\n\n", g.pkg)
+ if err := b.write(g.output); err != nil {
+ return err
+ }
+
+ return g.imports.write(g.output)
+}
+
+// writeTypeChecks writes a statement to force the compiler to perform a type
+// check for all Marshallable types referenced by the generated code.
+func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
+ if len(ms) == 0 {
+ return nil
+ }
+
+ msl := make([]string, 0, len(ms))
+ for m, _ := range ms {
+ msl = append(msl, m)
+ }
+ sort.Strings(msl)
+
+ var buf bytes.Buffer
+ fmt.Fprint(&buf, "// Marshallable types used by this file.\n")
+
+ for _, m := range msl {
+ fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m)
+ }
+ fmt.Fprint(&buf, "\n")
+
+ _, err := fmt.Fprint(g.output, buf.String())
+ return err
+}
+
+// parse processes all input files passed this generator and produces a set of
+// parsed go ASTs.
+func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
+ debugf("go_marshal invoked with %d input files:\n", len(g.inputs))
+ for _, path := range g.inputs {
+ debugf(" %s\n", path)
+ }
+
+ files := make([]*ast.File, 0, len(g.inputs))
+ fsets := make([]*token.FileSet, 0, len(g.inputs))
+
+ for _, path := range g.inputs {
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
+ if err != nil {
+ // Not a valid input file?
+ return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err)
+ }
+
+ if debugEnabled() {
+ debugf("AST for %q:\n", path)
+ ast.Print(fset, f)
+ }
+
+ files = append(files, f)
+ fsets = append(fsets, fset)
+ }
+
+ return files, fsets, nil
+}
+
+// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// declarations for which we need to generate the Marshallable interface.
+func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
+ var types []*ast.TypeSpec
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Type declaration?
+ if !ok || gdecl.Tok != token.TYPE {
+ debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
+ continue
+ }
+ // Does it have a comment?
+ if gdecl.Doc == nil {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
+ continue
+ }
+ // Does the comment contain a "+marshal" line?
+ marked := false
+ for _, c := range gdecl.Doc.List {
+ if c.Text == "// +marshal" {
+ marked = true
+ break
+ }
+ }
+ if !marked {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ // We already confirmed we're in a type declaration earlier.
+ t := spec.(*ast.TypeSpec)
+ if _, ok := t.Type.(*ast.StructType); ok {
+ debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
+ }
+ debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ }
+ }
+ return types
+}
+
+// collectImports collects all imports from all input source files. Some of
+// these imports are copied to the generated output, if they're referenced by
+// the generated code.
+//
+// collectImports de-duplicates imports while building the list, and ensures
+// identifiers in the generated code don't conflict with any imported package
+// names.
+func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
+ badImportNames := make(map[string]bool)
+ for _, i := range badIdents {
+ badImportNames[i] = true
+ }
+
+ is := make(map[string]importStmt)
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Import statement?
+ if !ok || gdecl.Tok != token.IMPORT {
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
+ debugf("Collected import '%s' as '%s'\n", i.path, i.name)
+
+ // Make sure we have an import that doesn't use any local names that
+ // would conflict with identifiers in the generated code.
+ if len(i.name) == 1 {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
+ }
+ if badImportNames[i.name] {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
+ }
+ }
+ }
+ return is
+
+}
+
+func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+ // We're guaranteed to have only struct type specs by now. See
+ // Generator.collectMarshallabeTypes.
+ i := newInterfaceGenerator(t, fset)
+ i.validate()
+ i.emitMarshallable()
+ return i
+}
+
+// generateOneTestSuite generates a test suite for the automatically generated
+// implementations type t.
+func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
+ i := newTestGenerator(t, g.declaration)
+ i.emitTests()
+ return i
+}
+
+// Run is the entry point to code generation using g.
+//
+// Run parses all input source files specified in g and emits generated code.
+func (g *Generator) Run() error {
+ // Parse our input source files into ASTs and token sets.
+ asts, fsets, err := g.parse()
+ if err != nil {
+ return err
+ }
+
+ if len(asts) != len(fsets) {
+ panic("ASTs and FileSets don't match")
+ }
+
+ // Map of imports in source files; key = local package name, value = import
+ // path.
+ is := make(map[string]importStmt)
+ for i, a := range asts {
+ // Collect all imports from the source files. We may need to copy some
+ // of these to the generated code if they're referenced. This has to be
+ // done before the loop below because we need to process all ASTs before
+ // we start requesting imports to be copied one by one as we encounter
+ // them in each generated source.
+ for name, i := range g.collectImports(a, fsets[i]) {
+ is[name] = i
+ }
+ }
+
+ var impls []*interfaceGenerator
+ var ts []*testGenerator
+ // Set of Marshallable types referenced by generated code.
+ ms := make(map[string]struct{})
+ for i, a := range asts {
+ // Collect type declarations marked for code generation and generate
+ // Marshallable interfaces.
+ for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ impl := g.generateOne(t, fsets[i])
+ // Collect Marshallable types referenced by the generated code.
+ for ref, _ := range impl.ms {
+ ms[ref] = struct{}{}
+ }
+ impls = append(impls, impl)
+ // Collect imports referenced by the generated code and add them to
+ // the list of imports we need to copy to the generated code.
+ for name, _ := range impl.is {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ }
+ }
+ ts = append(ts, g.generateOneTestSuite(t))
+ }
+ }
+
+ // Tool was invoked with input files with no data structures marked for code
+ // generation. This is probably not what the user intended.
+ if len(impls) == 0 {
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
+ for _, i := range g.inputs {
+ fmt.Fprintf(&buf, " %s\n", i)
+ }
+ abort(buf.String())
+ }
+
+ // Write output file header. These include things like package name and
+ // import statements.
+ if err := g.writeHeader(); err != nil {
+ return err
+ }
+
+ // Write type checks for referenced marshallable types to output file.
+ if err := g.writeTypeChecks(ms); err != nil {
+ return err
+ }
+
+ // Write generated interfaces to output file.
+ for _, i := range impls {
+ if err := i.write(g.output); err != nil {
+ return err
+ }
+ }
+
+ // Write generated tests to test file.
+ return g.writeTests(ts)
+}
+
+// writeTests outputs tests for the generated interface implementations to a go
+// source file.
+func (g *Generator) writeTests(ts []*testGenerator) error {
+ var b sourceBuffer
+ b.emit("package %s_test\n\n", g.pkg)
+ if err := b.write(g.outputTest); err != nil {
+ return err
+ }
+
+ imports := newImportTable()
+ for _, t := range ts {
+ imports.merge(t.imports)
+ }
+
+ if err := imports.write(g.outputTest); err != nil {
+ return err
+ }
+
+ for _, t := range ts {
+ if err := t.write(g.outputTest); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
new file mode 100644
index 000000000..a712c14dc
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -0,0 +1,507 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+// interfaceGenerator generates marshalling interfaces for a single type.
+//
+// getState is not thread-safe.
+type interfaceGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // FileSet containing the tokens for the type we're processing.
+ f *token.FileSet
+
+ // is records external packages referenced by the generated implementation.
+ is map[string]struct{}
+
+ // ms records Marshallable types referenced by the generated implementation
+ // of t's interfaces.
+ ms map[string]struct{}
+
+ // as records embedded fields in t that are potentially not packed. The key
+ // is the accessor for the field.
+ as map[string]struct{}
+}
+
+// typeName returns the name of the type this g represents.
+func (g *interfaceGenerator) typeName() string {
+ return g.t.Name.Name
+}
+
+// newinterfaceGenerator creates a new interface generator.
+func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+ if _, ok := t.Type.(*ast.StructType); !ok {
+ panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+ }
+ g := &interfaceGenerator{
+ t: t,
+ r: receiverName(t),
+ f: fset,
+ is: make(map[string]struct{}),
+ ms: make(map[string]struct{}),
+ as: make(map[string]struct{}),
+ }
+ g.recordUsedMarshallable(g.typeName())
+ return g
+}
+
+func (g *interfaceGenerator) recordUsedMarshallable(m string) {
+ g.ms[m] = struct{}{}
+
+}
+
+func (g *interfaceGenerator) recordUsedImport(i string) {
+ g.is[i] = struct{}{}
+
+}
+
+func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
+ g.as[fieldName] = struct{}{}
+}
+
+func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
+ // This is guaranteed to succeed because g.t is always a struct.
+ st := g.t.Type.(*ast.StructType)
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with a
+// reference position to the input source. Same as abortAt, but uses g to
+// resolve p to position.
+func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
+ abortAt(g.f.Position(p), msg)
+}
+
+// validate ensures the type we're working with can be marshalled. These checks
+// are done ahead of time and in one place so we can make assumptions later.
+func (g *interfaceGenerator) validate() {
+ g.forEachField(func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ g.forEachField(func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we
+ // provide suggestions for some disallowed types and reject
+ // them, then attempt to marshal any remaining types by
+ // invoking the marshal.Marshallable interface on them. If
+ // these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code
+ // will fail with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n, _ *ast.Ident, len int) {
+ a := f.Type.(*ast.ArrayType)
+ if a.Len == nil {
+ g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Len.(*ast.BasicLit); !ok {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+
+ if len <= 0 {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
+ }
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+// scalarSize returns the size of type identified by t. If t isn't a primitive
+// type, the size isn't known at code generation time, and must be resolved via
+// the marshal.Marshallable interface.
+func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
+ switch t.Name {
+ case "int8", "uint8", "byte":
+ return 1, false
+ case "int16", "uint16":
+ return 2, false
+ case "int32", "uint32":
+ return 4, false
+ case "int64", "uint64":
+ return 8, false
+ default:
+ return 0, true
+ }
+}
+
+func (g *interfaceGenerator) shift(bufVar string, n int) {
+ g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
+}
+
+func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
+ g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
+}
+
+func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 2)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 4)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8":
+ g.emit("%s = int8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "uint8":
+ g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "byte":
+ g.emit("%s = %s[0]\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+
+ case "int16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+ case "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+
+ case "int32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+ case "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+
+ case "int64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ case "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ g.recordPotentiallyNonPackedField(accessor)
+ }
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+func (g *interfaceGenerator) emitMarshallable() {
+ // Is g.t a packed struct without consideing field types?
+ thisPacked := true
+ g.forEachField(func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if thisPacked {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ thisPacked = false
+ }
+ }
+ }
+ })
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n)))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n, t *ast.Ident, len int) {
+ if len < 1 {
+ // Zero-length arrays should've been rejected by validate().
+ panic("unreachable")
+ }
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size * len
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
new file mode 100644
index 000000000..df25cb5b2
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -0,0 +1,154 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "io"
+ "strings"
+)
+
+var standardImports = []string{
+ "fmt",
+ "reflect",
+ "testing",
+ "gvisor.dev/gvisor/tools/go_marshal/analysis",
+}
+
+type testGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // Imports used by generated code.
+ imports *importTable
+
+ // Import statement for the package declaring the type we generated code
+ // for. We need this to construct test instances for the type, since the
+ // tests aren't written in the same package.
+ decl *importStmt
+}
+
+func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
+ if _, ok := t.Type.(*ast.StructType); !ok {
+ panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+ }
+ g := &testGenerator{
+ t: t,
+ r: receiverName(t),
+ imports: newImportTable(),
+ }
+
+ for _, i := range standardImports {
+ g.imports.add(i).markUsed()
+ }
+ g.decl = g.imports.add(declaration)
+ g.decl.markUsed()
+
+ return g
+}
+
+func (g *testGenerator) typeName() string {
+ return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name)
+}
+
+func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
+ // This is guaranteed to succeed because g.t is always a struct.
+ st := g.t.Type.(*ast.StructType)
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
+func (g *testGenerator) testFuncName(base string) string {
+ return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
+}
+
+func (g *testGenerator) inTestFunction(name string, body func()) {
+ g.emit("func %s(t *testing.T) {\n", g.testFuncName(name))
+ g.inIndent(body)
+ g.emit("}\n\n")
+}
+
+func (g *testGenerator) emitTestNonZeroSize() {
+ g.inTestFunction("TestSizeNonZero", func() {
+ g.emit("x := &%s{}\n", g.typeName())
+ g.emit("if x.SizeBytes() == 0 {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSuspectAlignment() {
+ g.inTestFunction("TestSuspectAlignment", func() {
+ g.emit("x := %s{}\n", g.typeName())
+ g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
+ })
+}
+
+func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
+ g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() {
+ g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("buf := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalBytes(buf)\n")
+ g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalUnsafe(bufUnsafe)\n\n")
+
+ g.emit("y.UnmarshalBytes(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("z.UnmarshalUnsafe(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, z) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n")
+ })
+ g.emit("}\n")
+ g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTests() {
+ g.emitTestNonZeroSize()
+ g.emitTestSuspectAlignment()
+ g.emitTestMarshalUnmarshalPreservesData()
+}
+
+func (g *testGenerator) write(out io.Writer) error {
+ return g.sourceBuffer.write(out)
+}
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
new file mode 100644
index 000000000..967537abf
--- /dev/null
+++ b/tools/go_marshal/gomarshal/util.go
@@ -0,0 +1,387 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "io"
+ "os"
+ "path"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+var debug = flag.Bool("debug", false, "enables debugging output")
+
+// receiverName returns an appropriate receiver name given a type spec.
+func receiverName(t *ast.TypeSpec) string {
+ if len(t.Name.Name) < 1 {
+ // Zero length type name?
+ panic("unreachable")
+ }
+ return strings.ToLower(t.Name.Name[:1])
+}
+
+// kindString returns a user-friendly representation of an AST expr type.
+func kindString(e ast.Expr) string {
+ switch e.(type) {
+ case *ast.Ident:
+ return "scalar"
+ case *ast.ArrayType:
+ return "array"
+ case *ast.StructType:
+ return "struct"
+ case *ast.StarExpr:
+ return "pointer"
+ case *ast.FuncType:
+ return "function"
+ case *ast.InterfaceType:
+ return "interface"
+ case *ast.MapType:
+ return "map"
+ case *ast.ChanType:
+ return "channel"
+ default:
+ return reflect.TypeOf(e).String()
+ }
+}
+
+// fieldDispatcher is a collection of callbacks for handling different types of
+// fields in a struct declaration.
+type fieldDispatcher struct {
+ primitive func(n, t *ast.Ident)
+ selector func(n, tX, tSel *ast.Ident)
+ array func(n, t *ast.Ident, size int)
+ unhandled func(n *ast.Ident)
+}
+
+// Precondition: All dispatch callbacks that will be invoked must be
+// provided. Embedded fields are not allowed, len(f.Names) >= 1.
+func (fd fieldDispatcher) dispatch(f *ast.Field) {
+ // Each field declaration may actually be multiple declarations of the same
+ // type. For example, consider:
+ //
+ // type Point struct {
+ // x, y, z int
+ // }
+ //
+ // We invoke the call-backs once per such instance. Embedded fields are not
+ // allowed, and results in a panic.
+ if len(f.Names) < 1 {
+ panic("Precondition not met: attempted to dispatch on embedded field")
+ }
+
+ for _, name := range f.Names {
+ switch v := f.Type.(type) {
+ case *ast.Ident:
+ fd.primitive(name, v)
+ case *ast.SelectorExpr:
+ fd.selector(name, v.X.(*ast.Ident), v.Sel)
+ case *ast.ArrayType:
+ len := 0
+ if v.Len != nil {
+ // Non-literal array length is handled by generatorInterfaces.validate().
+ if lenLit, ok := v.Len.(*ast.BasicLit); ok {
+ var err error
+ len, err = strconv.Atoi(lenLit.Value)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }
+ switch t := v.Elt.(type) {
+ case *ast.Ident:
+ fd.array(name, t, len)
+ default:
+ fd.array(name, nil, len)
+ }
+ default:
+ fd.unhandled(name)
+ }
+ }
+}
+
+// debugEnabled indicates whether debugging is enabled for gomarshal.
+func debugEnabled() bool {
+ return *debug
+}
+
+// abort aborts the go_marshal tool with the given error message.
+func abort(msg string) {
+ if !strings.HasSuffix(msg, "\n") {
+ msg += "\n"
+ }
+ fmt.Print(msg)
+ os.Exit(1)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with
+// a reference position to the input source.
+func abortAt(p token.Position, msg string) {
+ abort(fmt.Sprintf("%v:\n %s\n", p, msg))
+}
+
+// debugf conditionally prints a debug message.
+func debugf(f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf(f, a...)
+ }
+}
+
+// debugfAt conditionally prints a debug message with a reference to a position
+// in the input source.
+func debugfAt(p token.Position, f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...))
+ }
+}
+
+// emit generates a line of code in the output file.
+//
+// emit is a wrapper around writing a formatted string to the output
+// buffer. emit can be invoked in one of two ways:
+//
+// (1) emit("some string")
+// When emit is called with a single string argument, it is simply copied to
+// the output buffer without any further formatting.
+// (2) emit(fmtString, args...)
+// emit can also be invoked in a similar fashion to *Printf() functions,
+// where the first argument is a format string.
+//
+// Calling emit with a single argument that is not a string will result in a
+// panic, as the caller's intent is ambiguous.
+func emit(out io.Writer, indent int, a ...interface{}) {
+ const spacesPerIndentLevel = 4
+
+ if len(a) < 1 {
+ panic("emit() called with no arguments")
+ }
+
+ if indent > 0 {
+ if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
+ // Writing to the emit output should not fail. Typically the output
+ // is a byte.Buffer; writes to these never fail.
+ panic(err)
+ }
+ }
+
+ first, ok := a[0].(string)
+ if !ok {
+ // First argument must be either the string to emit (case 1 from
+ // function-level comment), or a format string (case 2).
+ panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
+ }
+
+ if len(a) == 1 {
+ // Single string argument. Assume no formatting requested.
+ if _, err := fmt.Fprint(out, first); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+ return
+
+ }
+
+ // Formatting requested.
+ if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+}
+
+// sourceBuffer represents fragments of generated go source code.
+//
+// sourceBuffer provides a convenient way to build up go souce fragments in
+// memory. May be safely zero-value initialized. Not thread-safe.
+type sourceBuffer struct {
+ // Current indentation level.
+ indent int
+
+ // Memory buffer containing contents while they're being generated.
+ b bytes.Buffer
+}
+
+func (b *sourceBuffer) incIndent() {
+ b.indent++
+}
+
+func (b *sourceBuffer) decIndent() {
+ if b.indent <= 0 {
+ panic("decIndent() without matching incIndent()")
+ }
+ b.indent--
+}
+
+func (b *sourceBuffer) emit(a ...interface{}) {
+ emit(&b.b, b.indent, a...)
+}
+
+func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
+ emit(&b.b, 0 /*indent*/, a...)
+}
+
+func (b *sourceBuffer) inIndent(body func()) {
+ b.incIndent()
+ body()
+ b.decIndent()
+}
+
+func (b *sourceBuffer) write(out io.Writer) error {
+ _, err := fmt.Fprint(out, b.b.String())
+ return err
+}
+
+// Write implements io.Writer.Write.
+func (b *sourceBuffer) Write(buf []byte) (int, error) {
+ return (b.b.Write(buf))
+}
+
+// importStmt represents a single import statement.
+type importStmt struct {
+ // Local name of the imported package.
+ name string
+ // Import path.
+ path string
+ // Indicates whether the local name is an alias, or simply the final
+ // component of the path.
+ aliased bool
+ // Indicates whether this import was referenced by generated code.
+ used bool
+}
+
+func newImport(p string) *importStmt {
+ name := path.Base(p)
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: false,
+ }
+}
+
+func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
+ name := path.Base(p)
+ if name == "" || name == "/" || name == "." {
+ panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
+ f.Position(spec.Path.Pos()), name))
+ }
+ if spec.Name != nil {
+ name = spec.Name.Name
+ }
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: spec.Name != nil,
+ }
+}
+
+func (i *importStmt) String() string {
+ if i.aliased {
+ return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+ }
+ return fmt.Sprintf("\"%s\"", i.path)
+}
+
+func (i *importStmt) markUsed() {
+ i.used = true
+}
+
+func (i *importStmt) equivalent(other *importStmt) bool {
+ return i == other
+}
+
+// importTable represents a collection of importStmts.
+type importTable struct {
+ // Map of imports and whether they should be copied to the output.
+ is map[string]*importStmt
+}
+
+func newImportTable() *importTable {
+ return &importTable{
+ is: make(map[string]*importStmt),
+ }
+}
+
+// Merges import statements from other into i. Collisions in import statements
+// result in a panic.
+func (i *importTable) merge(other *importTable) {
+ for name, im := range other.is {
+ if dup, ok := i.is[name]; ok && dup.equivalent(im) {
+ panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
+ }
+
+ i.is[name] = im
+ }
+}
+
+func (i *importTable) add(s string) *importStmt {
+ n := newImport(s)
+ i.is[n.name] = n
+ return n
+}
+
+func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ n := newImportFromSpec(spec, f)
+ i.is[n.name] = n
+ return n
+}
+
+// Marks the import named n as used. If no such import is in the table, returns
+// false.
+func (i *importTable) markUsed(n string) bool {
+ if n, ok := i.is[n]; ok {
+ n.markUsed()
+ return true
+ }
+ return false
+}
+
+func (i *importTable) clear() {
+ for _, i := range i.is {
+ i.used = false
+ }
+}
+
+func (i *importTable) write(out io.Writer) error {
+ if len(i.is) == 0 {
+ // Nothing to import, we're done.
+ return nil
+ }
+
+ imports := make([]string, 0, len(i.is))
+ for _, i := range i.is {
+ if i.used {
+ imports = append(imports, i.String())
+ }
+ }
+ sort.Strings(imports)
+
+ var b sourceBuffer
+ b.emit("import (\n")
+ b.incIndent()
+ for _, i := range imports {
+ b.emit("%s\n", i)
+ }
+ b.decIndent()
+ b.emit(")\n\n")
+
+ return b.write(out)
+}
diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go
new file mode 100644
index 000000000..3d12eb93c
--- /dev/null
+++ b/tools/go_marshal/main.go
@@ -0,0 +1,73 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// go_marshal is a code generation utility for automatically generating code to
+// marshal go data structures to memory.
+//
+// This binary is typically run as part of the build process, and is invoked by
+// the go_marshal bazel rule defined in defs.bzl.
+//
+// See README.md.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/go_marshal/gomarshal"
+)
+
+var (
+ pkg = flag.String("pkg", "", "output package")
+ output = flag.String("output", "", "output file")
+ outputTest = flag.String("output_test", "", "output file for tests")
+ imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
+ declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on")
+)
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+ if len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ if *pkg == "" {
+ flag.Usage()
+ fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n")
+ os.Exit(1)
+ }
+
+ var extraImports []string
+ if len(*imports) > 0 {
+ // Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus
+ // we check for an empty imports list to avoid emitting an empty string
+ // as an import.
+ extraImports = strings.Split(*imports, ",")
+ }
+ g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports)
+ if err != nil {
+ panic(err)
+ }
+
+ if err := g.Run(); err != nil {
+ panic(err)
+ }
+}
diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD
new file mode 100644
index 000000000..47dda97a1
--- /dev/null
+++ b/tools/go_marshal/marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "marshal",
+ srcs = [
+ "marshal.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
new file mode 100644
index 000000000..a313a27ed
--- /dev/null
+++ b/tools/go_marshal/marshal/marshal.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal defines the Marshallable interface for
+// serialize/deserializing go data structures to/from memory, according to the
+// Linux ABI.
+//
+// Implementations of this interface are typically automatically generated by
+// tools/go_marshal. See the go_marshal README for details.
+package marshal
+
+// Marshallable represents a type that can be marshalled to and from memory.
+type Marshallable interface {
+ // SizeBytes is the size of the memory representation of a type in
+ // marshalled form.
+ SizeBytes() int
+
+ // MarshalBytes serializes a copy of a type to dst. dst must be at least
+ // SizeBytes() long.
+ MarshalBytes(dst []byte)
+
+ // UnmarshalBytes deserializes a type from src. src must be at least
+ // SizeBytes() long.
+ UnmarshalBytes(src []byte)
+
+ // Packed returns true if the marshalled size of the type is the same as the
+ // size it occupies in memory. This happens when the type has no fields
+ // starting at unaligned addresses (should always be true by default for ABI
+ // structs, verified by automatically generated tests when using
+ // go_marshal), and has no fields marked `marshal:"unaligned"`.
+ Packed() bool
+
+ // MarshalUnsafe serializes a type by bulk copying its in-memory
+ // representation to the dst buffer. This is only safe to do when the type
+ // has no implicit padding, see Marshallable.Packed. When Packed would
+ // return false, MarshalUnsafe should fall back to the safer but slower
+ // MarshalBytes.
+ MarshalUnsafe(dst []byte)
+
+ // UnmarshalUnsafe deserializes a type directly to the underlying memory
+ // allocated for the object by the runtime.
+ //
+ // This allows much faster unmarshalling of types which have no implicit
+ // padding, see Marshallable.Packed. When Packed would return false,
+ // UnmarshalUnsafe should fall back to the safer but slower unmarshal
+ // mechanism implemented in UnmarshalBytes (usually by calling
+ // UnmarshalBytes directly).
+ UnmarshalUnsafe(src []byte)
+}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
new file mode 100644
index 000000000..fa82f8e9b
--- /dev/null
+++ b/tools/go_marshal/test/BUILD
@@ -0,0 +1,31 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+package_group(
+ name = "gomarshal_test",
+ packages = [
+ "//tools/go_marshal/test/...",
+ ],
+)
+
+go_test(
+ name = "benchmark_test",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ ":test",
+ "//pkg/binary",
+ "//pkg/sentry/usermem",
+ "//tools/go_marshal/analysis",
+ ],
+)
+
+go_library(
+ name = "test",
+ testonly = 1,
+ srcs = ["test.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/test",
+ deps = ["//tools/go_marshal/test/external"],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
new file mode 100644
index 000000000..e70db06d8
--- /dev/null
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -0,0 +1,178 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package benchmark_test
+
+import (
+ "bytes"
+ encbin "encoding/binary"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ test "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// Marshalling using the standard encoding/binary package.
+func BenchmarkEncodingBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := encbin.Size(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := bytes.NewBuffer(make([]byte, size))
+ buf.Reset()
+ if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil {
+ b.Error("Write:", err)
+ }
+ if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil {
+ b.Error("Read:", err)
+ }
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling using the sentry's binary.Marshal.
+func BenchmarkBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling field-by-field with manually-written code.
+func BenchmarkMarshalManual(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, s1.SizeBytes())
+
+ // Marshal
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, 0)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec))
+
+ // Unmarshal
+ s2.Dev = usermem.ByteOrder.Uint64(buf[0:8])
+ s2.Ino = usermem.ByteOrder.Uint64(buf[8:16])
+ s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24])
+ s2.Mode = usermem.ByteOrder.Uint32(buf[24:28])
+ s2.UID = usermem.ByteOrder.Uint32(buf[28:32])
+ s2.GID = usermem.ByteOrder.Uint32(buf[32:36])
+ // Padding: buf[36:40]
+ s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48])
+ s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56]))
+ s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64]))
+ s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72]))
+ s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80]))
+ s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88]))
+ s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96]))
+ s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104]))
+ s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112]))
+ s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120]))
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal safe API.
+func BenchmarkGoMarshalSafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalBytes(buf)
+ s2.UnmarshalBytes(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal unsafe API.
+func BenchmarkGoMarshalUnsafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalUnsafe(buf)
+ s2.UnmarshalUnsafe(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD
new file mode 100644
index 000000000..8fb43179b
--- /dev/null
+++ b/tools/go_marshal/test/external/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"])
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "external",
+ testonly = 1,
+ srcs = ["external.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external",
+ visibility = ["//tools/go_marshal/test:gomarshal_test"],
+)
diff --git a/runsc/test/root/root.go b/tools/go_marshal/test/external/external.go
index 349c752cc..4be3722f3 100644
--- a/runsc/test/root/root.go
+++ b/tools/go_marshal/test/external/external.go
@@ -1,4 +1,4 @@
-// Copyright 2018 The gVisor Authors.
+// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,5 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package root is empty. See chroot_test.go for description.
-package root
+// Package external defines types we can import for testing.
+package external
+
+// External is a public Marshallable type for use in testing.
+//
+// +marshal
+type External struct {
+ j int64
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
new file mode 100644
index 000000000..8de02d707
--- /dev/null
+++ b/tools/go_marshal/test/test.go
@@ -0,0 +1,105 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test contains data structures for testing the go_marshal tool.
+package test
+
+import (
+ // We're intentionally using a package name alias here even though it's not
+ // necessary to test the code generator's ability to handle package aliases.
+ ex "gvisor.dev/gvisor/tools/go_marshal/test/external"
+)
+
+// Type1 is a test data type.
+//
+// +marshal
+type Type1 struct {
+ a Type2
+ x, y int64 // Multiple field names.
+ b byte `marshal:"unaligned"` // Short field.
+ c uint64
+ _ uint32 // Unnamed scalar field.
+ _ [6]byte // Unnamed vector field, typical padding.
+ _ [2]byte
+ xs [8]int32
+ as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects.
+ ss Type3
+}
+
+// Type2 is a test data type.
+//
+// +marshal
+type Type2 struct {
+ n int64
+ c byte
+ _ [7]byte
+ m int64
+ a int64
+}
+
+// Type3 is a test data type.
+//
+// +marshal
+type Type3 struct {
+ s int64
+ x ex.External // Type defined in another package.
+}
+
+// Type4 is a test data type.
+//
+// +marshal
+type Type4 struct {
+ c byte
+ x int64 `marshal:"unaligned"`
+ d byte
+ _ [7]byte
+}
+
+// Type5 is a test data type.
+//
+// +marshal
+type Type5 struct {
+ n int64
+ t Type4
+ m int64
+}
+
+// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
+type Timespec struct {
+ Sec int64
+ Nsec int64
+}
+
+// Stat represents struct stat.
+//
+// +marshal
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ _ int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [3]int64
+}
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
index aeba197e2..3ce36c1c8 100644
--- a/tools/go_stateify/defs.bzl
+++ b/tools/go_stateify/defs.bzl
@@ -35,7 +35,7 @@ go_library(
)
"""
-load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library")
def _go_stateify_impl(ctx):
"""Implementation for the stateify tool."""
@@ -60,28 +60,57 @@ def _go_stateify_impl(ctx):
executable = ctx.executable._tool,
)
-# Generates save and restore logic from a set of Go files.
-#
-# Args:
-# name: the name of the rule.
-# srcs: the input source files. These files should include all structs in the package that need to be saved.
-# imports: an optional list of extra non-aliased, Go-style absolute import paths.
-# out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library.
-# package: the package name for the input sources.
go_stateify = rule(
implementation = _go_stateify_impl,
+ doc = "Generates save and restore logic from a set of Go files.",
attrs = {
- "srcs": attr.label_list(mandatory = True, allow_files = True),
- "imports": attr.string_list(mandatory = False),
- "package": attr.string(mandatory = True),
- "out": attr.output(mandatory = True),
- "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_stateify:stateify")),
+ "srcs": attr.label_list(
+ doc = """
+The input source files. These files should include all structs in the package
+that need to be saved.
+""",
+ mandatory = True,
+ allow_files = True,
+ ),
+ "imports": attr.string_list(
+ doc = """
+An optional list of extra non-aliased, Go-style absolute import paths required
+for statified types.
+""",
+ mandatory = False,
+ ),
+ "package": attr.string(
+ doc = "The package name for the input sources.",
+ mandatory = True,
+ ),
+ "out": attr.output(
+ doc = """
+The name of the generated file output. This must not conflict with any other
+files and must be added to the srcs of the relevant go_library.
+""",
+ mandatory = True,
+ ),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/go_stateify:stateify"),
+ ),
"_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"),
},
)
def go_library(name, srcs, deps = [], imports = [], **kwargs):
- """wraps the standard go_library and does stateification."""
+ """Standard go_library wrapped which generates state source files.
+
+ Args:
+ name: the name of the go_library rule.
+ srcs: sources of the go_library. Each will be processed for stateify
+ annotations.
+ deps: dependencies for the go_library.
+ imports: an optional list of extra non-aliased, Go-style absolute import
+ paths required for stateified types.
+ **kwargs: passed to go_library.
+ """
if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs:
# Only do stateification for non-state packages without manual autogen.
go_stateify(
@@ -105,9 +134,3 @@ def go_library(name, srcs, deps = [], imports = [], **kwargs):
deps = all_deps,
**kwargs
)
-
-def go_test(**kwargs):
- """Wraps the standard go_test."""
- _go_test(
- **kwargs
- )
diff --git a/tools/image_build.sh b/tools/image_build.sh
new file mode 100755
index 000000000..9b20a740d
--- /dev/null
+++ b/tools/image_build.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script is responsible for building a new GCP image that: 1) has nested
+# virtualization enabled, and 2) has been completely set up with the
+# image_setup.sh script. This script should be idempotent, as we memoize the
+# setup script with a hash and check for that name.
+#
+# The GCP project name should be defined via a gcloud config.
+
+set -xeo pipefail
+
+# Parameters.
+declare -r ZONE=${ZONE:-us-central1-f}
+declare -r USERNAME=${USERNAME:-test}
+declare -r IMAGE_PROJECT=${IMAGE_PROJECT:-ubuntu-os-cloud}
+declare -r IMAGE_FAMILY=${IMAGE_FAMILY:-ubuntu-1604-lts}
+
+# Random names.
+declare -r DISK_NAME=$(mktemp -u disk-XXXXXX | tr A-Z a-z)
+declare -r SNAPSHOT_NAME=$(mktemp -u snapshot-XXXXXX | tr A-Z a-z)
+declare -r INSTANCE_NAME=$(mktemp -u build-XXXXXX | tr A-Z a-z)
+
+# Hashes inputs.
+declare -r SETUP_BLOB=$(echo ${ZONE} ${USERNAME} ${IMAGE_PROJECT} ${IMAGE_FAMILY} && sha256sum "$@")
+declare -r SETUP_HASH=$(echo ${SETUP_BLOB} | sha256sum - | cut -d' ' -f1 | cut -c 1-16)
+declare -r IMAGE_NAME=${IMAGE_NAME:-image-}${SETUP_HASH}
+
+# Does the image already exist? Skip the build.
+declare -r existing=$(gcloud compute images list --filter="name=(${IMAGE_NAME})" --format="value(name)")
+if ! [[ -z "${existing}" ]]; then
+ echo "${existing}"
+ exit 0
+fi
+
+# Set the zone for all actions.
+gcloud config set compute/zone "${ZONE}"
+
+# Start a unique instance. Note that this instance will have a unique persistent
+# disk as it's boot disk with the same name as the instance.
+gcloud compute instances create \
+ --quiet \
+ --image-project "${IMAGE_PROJECT}" \
+ --image-family "${IMAGE_FAMILY}" \
+ --boot-disk-size "200GB" \
+ "${INSTANCE_NAME}"
+function cleanup {
+ gcloud compute instances delete --quiet "${INSTANCE_NAME}"
+}
+trap cleanup EXIT
+
+# Wait for the instance to become available.
+declare attempts=0
+while [[ "${attempts}" -lt 30 ]]; do
+ attempts=$((${attempts}+1))
+ if gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- true; then
+ break
+ fi
+done
+if [[ "${attempts}" -ge 30 ]]; then
+ echo "too many attempts: failed"
+ exit 1
+fi
+
+# Run the install scripts provided.
+for arg; do
+ gcloud compute ssh "${USERNAME}"@"${INSTANCE_NAME}" -- sudo bash - <"${arg}"
+done
+
+# Stop the instance; required before creating an image.
+gcloud compute instances stop --quiet "${INSTANCE_NAME}"
+
+# Create a snapshot of the instance disk.
+gcloud compute disks snapshot \
+ --quiet \
+ --zone="${ZONE}" \
+ --snapshot-names="${SNAPSHOT_NAME}" \
+ "${INSTANCE_NAME}"
+
+# Create the disk image.
+gcloud compute images create \
+ --quiet \
+ --source-snapshot="${SNAPSHOT_NAME}" \
+ --licenses="https://www.googleapis.com/compute/v1/projects/vm-options/global/licenses/enable-vmx" \
+ "${IMAGE_NAME}"
diff --git a/tools/make_repository.sh b/tools/make_repository.sh
new file mode 100755
index 000000000..071f72b74
--- /dev/null
+++ b/tools/make_repository.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Parse arguments. We require more than two arguments, which are the private
+# keyring, the e-mail associated with the signer, and the list of packages.
+if [ "$#" -le 3 ]; then
+ echo "usage: $0 <private-key> <signer-email> <component> <packages...>"
+ exit 1
+fi
+declare -r private_key=$(readlink -e "$1")
+declare -r signer="$2"
+declare -r component="$3"
+shift; shift; shift
+
+# Verbose from this point.
+set -xeo pipefail
+
+# Create a temporary working directory. We don't remove this, as we ultimately
+# print this result and allow the caller to copy wherever they would like.
+declare -r tmpdir=$(mktemp -d /tmp/repoXXXXXX)
+
+# Create a temporary keyring, and ensure it is cleaned up.
+declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg)
+cleanup() {
+ rm -f "${keyring}"
+}
+trap cleanup EXIT
+gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2
+
+# Copy the packages, and ensure permissions are correct.
+for pkg in "$@"; do
+ name=$(basename "${pkg}" .deb)
+ name=$(basename "${name}" .changes)
+ arch=${name##*_}
+ if [[ "${name}" == "${arch}" ]]; then
+ continue # Not a regular package.
+ fi
+ mkdir -p "${tmpdir}"/"${component}"/binary-"${arch}"
+ cp -a "${pkg}" "${tmpdir}"/"${component}"/binary-"${arch}"
+done
+find "${tmpdir}" -type f -exec chmod 0644 {} \;
+
+# Ensure there are no symlinks hanging around; these may be remnants of the
+# build process. They may be useful for other things, but we are going to build
+# an index of the actual packages here.
+find "${tmpdir}" -type l -exec rm -f {} \;
+
+# Sign all packages.
+for file in "${tmpdir}"/"${component}"/binary-*/*.deb; do
+ dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2
+done
+
+# Build the package list.
+for dir in "${tmpdir}"/"${component}"/binary-*; do
+ (cd "${dir}" && apt-ftparchive packages . | gzip > Packages.gz)
+done
+
+# Build the release list.
+(cd "${tmpdir}" && apt-ftparchive release . > Release)
+
+# Sign the release.
+(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release >&2)
+(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release >&2)
+
+# Show the results.
+echo "${tmpdir}"
diff --git a/tools/run_build.sh b/tools/run_build.sh
deleted file mode 100755
index 7f6ada480..000000000
--- a/tools/run_build.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Fail on any error.
-set -e
-# Display commands to stderr.
-set -x
-
-# Install the latest version of Bazel and log the version.
-(which use_bazel.sh && use_bazel.sh latest) || which bazel
-bazel version
-
-# Switch into the workspace.
-if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
- cd git/repo
-elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
- cd github/repo
-fi
-
-# Build runsc.
-bazel build -c opt --strip=never //runsc
-
-# Move the runsc binary into "latest" directory, and also a directory with the
-# current date.
-if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
- latest_dir="${KOKORO_ARTIFACTS_DIR}"/latest
- today_dir="${KOKORO_ARTIFACTS_DIR}"/"$(date -Idate)"
- runsc="bazel-bin/runsc/linux_amd64_pure/runsc"
-
- mkdir -p "${latest_dir}" "${today_dir}"
- cp "${runsc}" "${latest_dir}"
- cp "${runsc}" "${today_dir}"
-
- sha512sum "${latest_dir}"/runsc | awk '{print $1 " runsc"}' > "${latest_dir}"/runsc.sha512
- cp "${latest_dir}"/runsc.sha512 "${today_dir}"/runsc.sha512
-fi
diff --git a/tools/run_tests.sh b/tools/run_tests.sh
deleted file mode 100755
index 3e7c4394d..000000000
--- a/tools/run_tests.sh
+++ /dev/null
@@ -1,302 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Fail on any error. Treat unset variables as error. Print commands as executed.
-set -eux
-
-###################
-# GLOBAL ENV VARS #
-###################
-
-if [[ -v KOKORO_GIT_COMMIT ]] && [[ -d git/repo ]]; then
- readonly WORKSPACE_DIR="${PWD}/git/repo"
-elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
- readonly WORKSPACE_DIR="${PWD}/github/repo"
-else
- readonly WORKSPACE_DIR="${PWD}"
-fi
-
-# Used to configure RBE.
-readonly CLOUD_PROJECT_ID="gvisor-rbe"
-readonly RBE_PROJECT_ID="projects/${CLOUD_PROJECT_ID}/instances/default_instance"
-
-# Random runtime name to avoid collisions.
-readonly RUNTIME="runsc_test_$((RANDOM))"
-
-# Packages that will be built and tested.
-readonly BUILD_PACKAGES=("//...")
-readonly TEST_PACKAGES=("//pkg/..." "//runsc/..." "//tools/...")
-
-#######################
-# BAZEL CONFIGURATION #
-#######################
-
-# Install the latest version of Bazel and log the version.
-(which use_bazel.sh && use_bazel.sh 0.28.0) || which bazel
-bazel version
-
-# Load the kvm module.
-sudo -n -E modprobe kvm
-
-# General Bazel build/test flags.
-BAZEL_BUILD_FLAGS=(
- "--show_timestamps"
- "--test_output=errors"
- "--keep_going"
- "--verbose_failures=true"
-)
-
-# Bazel build/test for RBE, a super-set of BAZEL_BUILD_FLAGS.
-BAZEL_BUILD_RBE_FLAGS=(
- "${BAZEL_BUILD_FLAGS[@]}"
- "--config=remote"
- "--project_id=${CLOUD_PROJECT_ID}"
- "--remote_instance_name=${RBE_PROJECT_ID}"
-)
-if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
- BAZEL_BUILD_RBE_FLAGS=(
- "${BAZEL_BUILD_RBE_FLAGS[@]}"
- "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
- )
-fi
-
-####################
-# Helper Functions #
-####################
-
-sanity_checks() {
- cd ${WORKSPACE_DIR}
- bazel run //:gazelle -- update-repos -from_file=go.mod
- git diff --exit-code WORKSPACE
-}
-
-build_everything() {
- FLAVOR="${1}"
-
- cd ${WORKSPACE_DIR}
- bazel build \
- -c "${FLAVOR}" "${BAZEL_BUILD_RBE_FLAGS[@]}" \
- "${BUILD_PACKAGES[@]}"
-}
-
-build_runsc_debian() {
- cd ${WORKSPACE_DIR}
-
- # TODO(b/135475885): pkg_deb is incompatible with Python3.
- # https://github.com/bazelbuild/bazel/issues/8443
- bazel build --host_force_python=py2 runsc:runsc-debian
-}
-
-# Run simple tests runs the tests that require no special setup or
-# configuration.
-run_simple_tests() {
- cd ${WORKSPACE_DIR}
- bazel test \
- "${BAZEL_BUILD_FLAGS[@]}" \
- "${TEST_PACKAGES[@]}"
-}
-
-install_runtime() {
- cd ${WORKSPACE_DIR}
- sudo -n ${WORKSPACE_DIR}/runsc/test/install.sh --runtime ${RUNTIME}
-}
-
-install_helper() {
- PACKAGE="${1}"
- TAG="${2}"
- GOPATH="${3}"
-
- # Clone the repository.
- mkdir -p "${GOPATH}"/src/$(dirname "${PACKAGE}") && \
- git clone https://"${PACKAGE}" "${GOPATH}"/src/"${PACKAGE}"
-
- # Checkout and build the repository.
- (cd "${GOPATH}"/src/"${PACKAGE}" && \
- git checkout "${TAG}" && \
- GOPATH="${GOPATH}" make && \
- sudo -n -E env GOPATH="${GOPATH}" make install)
-}
-
-# Install dependencies for the crictl tests.
-install_crictl_test_deps() {
- sudo -n -E apt-get update
- sudo -n -E apt-get install -y btrfs-tools libseccomp-dev
-
- # Install containerd & cri-tools.
- GOPATH=$(mktemp -d --tmpdir gopathXXXXX)
- install_helper github.com/containerd/containerd v1.2.2 "${GOPATH}"
- install_helper github.com/kubernetes-sigs/cri-tools v1.11.0 "${GOPATH}"
-
- # Install gvisor-containerd-shim.
- local latest=/tmp/gvisor-containerd-shim-latest
- local shim_path=/tmp/gvisor-containerd-shim
- wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/latest -O ${latest}
- wget --no-verbose https://storage.googleapis.com/cri-containerd-staging/gvisor-containerd-shim/gvisor-containerd-shim-$(cat ${latest}) -O ${shim_path}
- chmod +x ${shim_path}
- sudo -n -E mv ${shim_path} /usr/local/bin
-
- # Configure containerd-shim.
- local shim_config_path=/etc/containerd
- local shim_config_tmp_path=/tmp/gvisor-containerd-shim.toml
- sudo -n -E mkdir -p ${shim_config_path}
- cat > ${shim_config_tmp_path} <<-EOF
- runc_shim = "/usr/local/bin/containerd-shim"
-
- [runsc_config]
- debug = "true"
- debug-log = "/tmp/runsc-logs/"
- strace = "true"
- file-access = "shared"
-EOF
- sudo mv ${shim_config_tmp_path} ${shim_config_path}
-
- # Configure CNI.
- (cd "${GOPATH}" && sudo -n -E env PATH="${PATH}" GOPATH="${GOPATH}" \
- src/github.com/containerd/containerd/script/setup/install-cni)
-}
-
-# Run the tests that require docker.
-run_docker_tests() {
- cd ${WORKSPACE_DIR}
-
- # Run tests with a default runtime (runc).
- bazel test \
- "${BAZEL_BUILD_FLAGS[@]}" \
- --test_env=RUNSC_RUNTIME="" \
- --test_output=all \
- //runsc/test/image:image_test
-
- # These names are used to exclude tests not supported in certain
- # configuration, e.g. save/restore not supported with hostnet.
- # Run runsc tests with docker that are tagged manual.
- #
- # The --nocache_test_results option is used here to eliminate cached results
- # from the previous run for the runc runtime.
- bazel test \
- "${BAZEL_BUILD_FLAGS[@]}" \
- --test_env=RUNSC_RUNTIME="${RUNTIME}" \
- --test_output=all \
- --nocache_test_results \
- --test_output=streamed \
- //runsc/test/integration:integration_test \
- //runsc/test/integration:integration_test_hostnet \
- //runsc/test/integration:integration_test_overlay \
- //runsc/test/integration:integration_test_kvm \
- //runsc/test/image:image_test \
- //runsc/test/image:image_test_overlay \
- //runsc/test/image:image_test_hostnet \
- //runsc/test/image:image_test_kvm
-}
-
-# Run the tests that require root.
-run_root_tests() {
- cd ${WORKSPACE_DIR}
- bazel build //runsc/test/root:root_test
- local root_test=$(find -L ./bazel-bin/ -executable -type f -name root_test | grep __main__)
- if [[ ! -f "${root_test}" ]]; then
- echo "root_test executable not found"
- exit 1
- fi
- sudo -n -E RUNSC_RUNTIME="${RUNTIME}" RUNSC_EXEC=/tmp/"${RUNTIME}"/runsc ${root_test}
-}
-
-# Run syscall unit tests.
-run_syscall_tests() {
- cd ${WORKSPACE_DIR}
- bazel test "${BAZEL_BUILD_RBE_FLAGS[@]}" \
- --test_tag_filters=runsc_ptrace //test/syscalls/...
-}
-
-run_runsc_do_tests() {
- local runsc=$(find bazel-bin/runsc -type f -executable -name "runsc" | head -n1)
-
- # run runsc do without root privileges.
- ${runsc} --rootless do true
- ${runsc} --rootless --network=none do true
-
- # run runsc do with root privileges.
- sudo -n -E ${runsc} do true
-}
-
-# Find and rename all test xml and log files so that Sponge can pick them up.
-# XML files must be named sponge_log.xml, and log files must be named
-# sponge_log.log. We move all such files into KOKORO_ARTIFACTS_DIR, in a
-# subdirectory named with the test name.
-upload_test_artifacts() {
- # Skip if no kokoro directory.
- [[ -v KOKORO_ARTIFACTS_DIR ]] || return
-
- cd ${WORKSPACE_DIR}
- find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
- tar --create --files-from - --transform 's/test\./sponge_log./' |
- tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
- if [[ -d "/tmp/${RUNTIME}/logs" ]]; then
- tar --create --gzip "--file=${KOKORO_ARTIFACTS_DIR}/runsc-logs.tar.gz" -C /tmp/ ${RUNTIME}/logs
- fi
-}
-
-# Finish runs at exit, even in the event of an error, and uploads all test
-# artifacts.
-finish() {
- # Grab the last exit code, we will return it.
- local exit_code=${?}
- upload_test_artifacts
- exit ${exit_code}
-}
-
-# Run bazel in a docker container
-build_in_docker() {
- cd ${WORKSPACE_DIR}
- bazel clean
- bazel shutdown
- make
- make runsc
- make bazel-shutdown
-}
-
-########
-# MAIN #
-########
-
-main() {
- # Register finish to run at exit.
- trap finish EXIT
-
- # Build and run the simple tests.
- sanity_checks
- build_everything opt
- run_simple_tests
-
- # So far so good. Install more deps and run the integration tests.
- install_runtime
- install_crictl_test_deps
- run_docker_tests
- run_root_tests
-
- run_syscall_tests
- run_runsc_do_tests
-
- build_runsc_debian
-
- # Build other flavors too.
- build_everything dbg
-
- build_in_docker
- # No need to call "finish" here, it will happen at exit.
-}
-
-# Kick it off.
-main
diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh
index 64a905fc9..fb09ff331 100755
--- a/tools/workspace_status.sh
+++ b/tools/workspace_status.sh
@@ -14,4 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-echo VERSION $(git describe --always --tags --abbrev=12 --dirty)
+# The STABLE_ prefix will trigger a re-link if it changes.
+echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty)