summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc2
-rw-r--r--CONTRIBUTING.md32
-rw-r--r--README.md2
-rw-r--r--WORKSPACE55
-rw-r--r--cloudbuild/go.Dockerfile2
-rw-r--r--cloudbuild/go.yaml22
-rw-r--r--kokoro/build.cfg7
-rw-r--r--kokoro/continuous.cfg11
-rw-r--r--kokoro/go.cfg14
-rw-r--r--kokoro/go_tests.cfg (renamed from kokoro/go_test.cfg)0
-rw-r--r--kokoro/presubmit.cfg11
-rw-r--r--kokoro/release-nightly.cfg10
-rwxr-xr-xkokoro/run_build.sh19
-rwxr-xr-xkokoro/ubuntu1604/20_bazel.sh4
-rw-r--r--pkg/abi/linux/BUILD5
-rw-r--r--pkg/abi/linux/signalfd.go45
-rw-r--r--pkg/amutex/BUILD3
-rw-r--r--pkg/atomicbitops/BUILD3
-rw-r--r--pkg/binary/BUILD3
-rw-r--r--pkg/bits/BUILD3
-rw-r--r--pkg/bpf/BUILD4
-rw-r--r--pkg/compressio/BUILD3
-rw-r--r--pkg/cpuid/BUILD4
-rw-r--r--pkg/eventchannel/BUILD3
-rw-r--r--pkg/fd/BUILD3
-rw-r--r--pkg/fdchannel/BUILD3
-rw-r--r--pkg/flipcall/BUILD4
-rw-r--r--pkg/flipcall/ctrl_futex.go1
-rw-r--r--pkg/flipcall/flipcall.go10
-rw-r--r--pkg/flipcall/flipcall_test.go19
-rw-r--r--pkg/flipcall/flipcall_unsafe.go18
-rw-r--r--pkg/fspath/BUILD3
-rw-r--r--pkg/gate/BUILD3
-rw-r--r--pkg/ilist/BUILD3
-rw-r--r--pkg/linewriter/BUILD3
-rw-r--r--pkg/log/BUILD3
-rw-r--r--pkg/metric/BUILD10
-rw-r--r--pkg/p9/BUILD6
-rw-r--r--pkg/p9/client.go273
-rw-r--r--pkg/p9/client_test.go50
-rw-r--r--pkg/p9/handlers.go53
-rw-r--r--pkg/p9/messages.go103
-rw-r--r--pkg/p9/p9.go2
-rw-r--r--pkg/p9/p9test/BUILD6
-rw-r--r--pkg/p9/p9test/client_test.go95
-rw-r--r--pkg/p9/p9test/p9test.go4
-rw-r--r--pkg/p9/server.go178
-rw-r--r--pkg/p9/transport.go5
-rw-r--r--pkg/p9/transport_flipcall.go263
-rw-r--r--pkg/p9/transport_test.go4
-rw-r--r--pkg/p9/version.go9
-rw-r--r--pkg/procid/BUILD3
-rw-r--r--pkg/refs/BUILD4
-rw-r--r--pkg/seccomp/BUILD4
-rw-r--r--pkg/secio/BUILD3
-rw-r--r--pkg/segment/test/BUILD3
-rw-r--r--pkg/sentry/BUILD2
-rw-r--r--pkg/sentry/arch/BUILD7
-rw-r--r--pkg/sentry/control/BUILD3
-rw-r--r--pkg/sentry/device/BUILD4
-rw-r--r--pkg/sentry/fs/BUILD4
-rw-r--r--pkg/sentry/fs/dirent.go4
-rw-r--r--pkg/sentry/fs/dirent_refs_test.go2
-rw-r--r--pkg/sentry/fs/fdpipe/BUILD4
-rw-r--r--pkg/sentry/fs/file.go23
-rw-r--r--pkg/sentry/fs/file_operations.go9
-rw-r--r--pkg/sentry/fs/file_overlay.go9
-rw-r--r--pkg/sentry/fs/fsutil/BUILD4
-rw-r--r--pkg/sentry/fs/fsutil/file.go6
-rw-r--r--pkg/sentry/fs/gofer/BUILD4
-rw-r--r--pkg/sentry/fs/host/BUILD4
-rw-r--r--pkg/sentry/fs/inotify.go5
-rw-r--r--pkg/sentry/fs/lock/BUILD4
-rw-r--r--pkg/sentry/fs/proc/BUILD4
-rw-r--r--pkg/sentry/fs/proc/seqfile/BUILD4
-rw-r--r--pkg/sentry/fs/ramfs/BUILD4
-rw-r--r--pkg/sentry/fs/splice.go162
-rw-r--r--pkg/sentry/fs/tmpfs/BUILD4
-rw-r--r--pkg/sentry/fs/tty/BUILD5
-rw-r--r--pkg/sentry/fs/tty/dir.go3
-rw-r--r--pkg/sentry/fs/tty/master.go17
-rw-r--r--pkg/sentry/fs/tty/slave.go13
-rw-r--r--pkg/sentry/fs/tty/terminal.go92
-rw-r--r--pkg/sentry/fsimpl/ext/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/benchmark/BUILD2
-rw-r--r--pkg/sentry/fsimpl/ext/directory.go8
-rw-r--r--pkg/sentry/fsimpl/ext/disklayout/BUILD4
-rw-r--r--pkg/sentry/fsimpl/ext/ext_test.go2
-rw-r--r--pkg/sentry/fsimpl/memfs/BUILD3
-rw-r--r--pkg/sentry/fsimpl/memfs/directory.go24
-rw-r--r--pkg/sentry/fsimpl/proc/BUILD3
-rw-r--r--pkg/sentry/hostcpu/BUILD4
-rw-r--r--pkg/sentry/hostcpu/getcpu_arm64.s28
-rw-r--r--pkg/sentry/kernel/BUILD11
-rw-r--r--pkg/sentry/kernel/epoll/BUILD4
-rw-r--r--pkg/sentry/kernel/eventfd/BUILD4
-rw-r--r--pkg/sentry/kernel/futex/BUILD4
-rw-r--r--pkg/sentry/kernel/memevent/BUILD7
-rw-r--r--pkg/sentry/kernel/pipe/BUILD4
-rw-r--r--pkg/sentry/kernel/pipe/buffer.go25
-rw-r--r--pkg/sentry/kernel/pipe/pipe.go82
-rw-r--r--pkg/sentry/kernel/pipe/reader_writer.go76
-rw-r--r--pkg/sentry/kernel/sched/BUILD3
-rw-r--r--pkg/sentry/kernel/semaphore/BUILD4
-rw-r--r--pkg/sentry/kernel/sessions.go20
-rw-r--r--pkg/sentry/kernel/signalfd/BUILD22
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go137
-rw-r--r--pkg/sentry/kernel/task.go8
-rw-r--r--pkg/sentry/kernel/task_signals.go18
-rw-r--r--pkg/sentry/kernel/task_start.go3
-rw-r--r--pkg/sentry/kernel/thread_group.go179
-rw-r--r--pkg/sentry/kernel/tty.go28
-rw-r--r--pkg/sentry/limits/BUILD4
-rw-r--r--pkg/sentry/memmap/BUILD4
-rw-r--r--pkg/sentry/mm/BUILD4
-rw-r--r--pkg/sentry/pgalloc/BUILD4
-rw-r--r--pkg/sentry/platform/interrupt/BUILD3
-rw-r--r--pkg/sentry/platform/kvm/BUILD3
-rw-r--r--pkg/sentry/platform/ptrace/ptrace_unsafe.go16
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go6
-rw-r--r--pkg/sentry/platform/ring0/pagetables/BUILD3
-rw-r--r--pkg/sentry/platform/safecopy/BUILD3
-rw-r--r--pkg/sentry/safemem/BUILD3
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go160
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD4
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD9
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go82
-rw-r--r--pkg/sentry/strace/BUILD7
-rw-r--r--pkg/sentry/strace/linux64.go1
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go77
-rw-r--r--pkg/sentry/syscalls/linux/sys_splice.go86
-rw-r--r--pkg/sentry/time/BUILD3
-rw-r--r--pkg/sentry/unimpl/BUILD7
-rw-r--r--pkg/sentry/usermem/BUILD4
-rw-r--r--pkg/sentry/vfs/BUILD3
-rw-r--r--pkg/sentry/vfs/file_description.go7
-rw-r--r--pkg/sleep/BUILD3
-rw-r--r--pkg/state/BUILD3
-rw-r--r--pkg/state/statefile/BUILD3
-rw-r--r--pkg/syserror/BUILD3
-rw-r--r--pkg/tcpip/BUILD4
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD3
-rw-r--r--pkg/tcpip/buffer/BUILD4
-rw-r--r--pkg/tcpip/hash/jenkins/BUILD3
-rw-r--r--pkg/tcpip/header/BUILD4
-rw-r--r--pkg/tcpip/header/ipv6.go25
-rw-r--r--pkg/tcpip/header/udp.go5
-rw-r--r--pkg/tcpip/link/channel/channel.go9
-rw-r--r--pkg/tcpip/link/fdbased/BUILD3
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go38
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go3
-rw-r--r--pkg/tcpip/link/loopback/loopback.go7
-rw-r--r--pkg/tcpip/link/muxed/BUILD3
-rw-r--r--pkg/tcpip/link/muxed/injectable.go12
-rw-r--r--pkg/tcpip/link/muxed/injectable_test.go4
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD3
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go11
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go4
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go21
-rw-r--r--pkg/tcpip/link/waitable/BUILD3
-rw-r--r--pkg/tcpip/link/waitable/waitable.go10
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go9
-rw-r--r--pkg/tcpip/network/BUILD2
-rw-r--r--pkg/tcpip/network/arp/BUILD3
-rw-r--r--pkg/tcpip/network/arp/arp_test.go10
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD4
-rw-r--r--pkg/tcpip/network/ip_test.go3
-rw-r--r--pkg/tcpip/network/ipv4/BUILD3
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go22
-rw-r--r--pkg/tcpip/network/ipv6/BUILD5
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go28
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go258
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go15
-rw-r--r--pkg/tcpip/ports/BUILD3
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go4
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go4
-rw-r--r--pkg/tcpip/stack/BUILD4
-rw-r--r--pkg/tcpip/stack/icmp_rate_limit.go49
-rw-r--r--pkg/tcpip/stack/nic.go71
-rw-r--r--pkg/tcpip/stack/registration.go37
-rw-r--r--pkg/tcpip/stack/stack.go33
-rw-r--r--pkg/tcpip/stack/stack_test.go278
-rw-r--r--pkg/tcpip/stack/transport_test.go33
-rw-r--r--pkg/tcpip/tcpip.go83
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go37
-rw-r--r--pkg/tcpip/transport/tcp/BUILD4
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go212
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go121
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go17
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD3
-rw-r--r--pkg/tcpip/transport/udp/BUILD4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go50
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go9
-rw-r--r--pkg/tmutex/BUILD3
-rw-r--r--pkg/unet/BUILD3
-rw-r--r--pkg/urpc/BUILD3
-rw-r--r--pkg/waiter/BUILD4
-rw-r--r--runsc/BUILD11
-rw-r--r--runsc/boot/BUILD1
-rw-r--r--runsc/boot/config.go26
-rw-r--r--runsc/boot/filter/config.go14
-rw-r--r--runsc/boot/loader.go32
-rw-r--r--runsc/boot/network.go18
-rw-r--r--runsc/boot/user.go28
-rw-r--r--runsc/boot/user_test.go3
-rw-r--r--runsc/cmd/exec.go1
-rw-r--r--runsc/container/container.go113
-rw-r--r--runsc/dockerutil/dockerutil.go23
-rw-r--r--runsc/fsgofer/filter/BUILD1
-rw-r--r--runsc/fsgofer/filter/config.go36
-rw-r--r--runsc/main.go4
-rw-r--r--runsc/sandbox/sandbox.go10
-rw-r--r--runsc/specutils/specutils.go55
-rw-r--r--runsc/testutil/testutil.go43
-rw-r--r--runsc/version.go2
-rwxr-xr-xrunsc/version_test.sh36
-rwxr-xr-xscripts/build.sh49
-rwxr-xr-xscripts/common.sh59
-rwxr-xr-xscripts/common_bazel.sh37
-rwxr-xr-xscripts/dev.sh73
-rwxr-xr-xscripts/docker_tests.sh6
-rwxr-xr-xscripts/go.sh11
-rwxr-xr-xscripts/hostnet_tests.sh5
-rwxr-xr-xscripts/kvm_tests.sh10
-rwxr-xr-xscripts/overlay_tests.sh5
-rwxr-xr-xscripts/release.sh6
-rwxr-xr-xscripts/root_tests.sh6
-rw-r--r--test/e2e/exec_test.go69
-rw-r--r--test/root/BUILD8
-rw-r--r--test/root/cgroup_test.go18
-rw-r--r--test/root/chroot_test.go16
-rw-r--r--test/root/main_test.go49
-rw-r--r--test/root/oom_score_adj_test.go376
-rw-r--r--test/root/root.go7
-rw-r--r--test/runtimes/BUILD46
-rw-r--r--test/runtimes/README.md5
-rw-r--r--test/runtimes/build_defs.bzl46
-rw-r--r--test/runtimes/common/BUILD20
-rw-r--r--test/runtimes/common/common.go114
-rw-r--r--test/runtimes/go/BUILD9
-rw-r--r--test/runtimes/go/Dockerfile35
-rw-r--r--test/runtimes/images/Dockerfile_go1.1210
-rw-r--r--test/runtimes/images/Dockerfile_java1130
-rw-r--r--test/runtimes/images/Dockerfile_nodejs12.4.028
-rw-r--r--test/runtimes/images/Dockerfile_php7.3.627
-rw-r--r--test/runtimes/images/Dockerfile_python3.7.330
-rw-r--r--test/runtimes/images/proctor/BUILD26
-rw-r--r--test/runtimes/images/proctor/go.go (renamed from test/runtimes/go/proctor-go.go)67
-rw-r--r--test/runtimes/images/proctor/java.go (renamed from test/runtimes/java/proctor-java.go)49
-rw-r--r--test/runtimes/images/proctor/nodejs.go46
-rw-r--r--test/runtimes/images/proctor/php.go (renamed from test/runtimes/php/proctor-php.go)36
-rw-r--r--test/runtimes/images/proctor/proctor.go154
-rw-r--r--test/runtimes/images/proctor/proctor_test.go (renamed from test/runtimes/common/common_test.go)11
-rw-r--r--test/runtimes/images/proctor/python.go (renamed from test/runtimes/python/proctor-python.go)34
-rw-r--r--test/runtimes/java/BUILD9
-rw-r--r--test/runtimes/java/Dockerfile36
-rw-r--r--test/runtimes/nodejs/BUILD9
-rw-r--r--test/runtimes/nodejs/Dockerfile31
-rw-r--r--test/runtimes/nodejs/proctor-nodejs.go60
-rw-r--r--test/runtimes/php/BUILD9
-rw-r--r--test/runtimes/php/Dockerfile31
-rw-r--r--test/runtimes/python/BUILD9
-rw-r--r--test/runtimes/python/Dockerfile33
-rw-r--r--test/runtimes/runner.go147
-rwxr-xr-x[-rw-r--r--]test/runtimes/runner.sh (renamed from kokoro/run_tests.sh)30
-rw-r--r--test/runtimes/runtimes_test.go93
-rw-r--r--test/syscalls/BUILD4
-rw-r--r--test/syscalls/linux/BUILD58
-rw-r--r--test/syscalls/linux/aio.cc155
-rw-r--r--test/syscalls/linux/chown.cc30
-rw-r--r--test/syscalls/linux/fcntl.cc47
-rw-r--r--test/syscalls/linux/kill.cc13
-rw-r--r--test/syscalls/linux/link.cc6
-rw-r--r--test/syscalls/linux/mremap.cc11
-rw-r--r--test/syscalls/linux/pipe.cc14
-rw-r--r--test/syscalls/linux/prctl.cc9
-rw-r--r--test/syscalls/linux/prctl_setuid.cc24
-rw-r--r--test/syscalls/linux/proc.cc4
-rw-r--r--test/syscalls/linux/proc_net.cc2
-rw-r--r--test/syscalls/linux/ptrace.cc11
-rw-r--r--test/syscalls/linux/pty.cc393
-rw-r--r--test/syscalls/linux/pty_root.cc68
-rw-r--r--test/syscalls/linux/sendfile.cc69
-rw-r--r--test/syscalls/linux/signalfd.cc333
-rw-r--r--test/syscalls/linux/sigstop.cc7
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc2
-rw-r--r--test/syscalls/linux/splice.cc194
-rw-r--r--test/syscalls/linux/sticky.cc31
-rw-r--r--test/syscalls/linux/timers.cc7
-rw-r--r--test/syscalls/linux/uidgid.cc23
-rw-r--r--test/syscalls/linux/unlink.cc2
-rw-r--r--test/syscalls/linux/vfork.cc7
-rw-r--r--test/syscalls/syscall_test_runner.go32
-rw-r--r--test/util/BUILD15
-rw-r--r--test/util/memory_util.h12
-rw-r--r--test/util/pty_util.cc45
-rw-r--r--test/util/pty_util.h30
-rw-r--r--test/util/test_util.cc4
-rw-r--r--test/util/test_util.h1
-rw-r--r--third_party/gvsync/downgradable_rwmutex_unsafe.go3
-rwxr-xr-xtools/go_branch.sh6
-rw-r--r--tools/go_marshal/BUILD14
-rw-r--r--tools/go_marshal/README.md164
-rw-r--r--tools/go_marshal/analysis/BUILD13
-rw-r--r--tools/go_marshal/analysis/analysis_unsafe.go175
-rw-r--r--tools/go_marshal/defs.bzl152
-rw-r--r--tools/go_marshal/gomarshal/BUILD17
-rw-r--r--tools/go_marshal/gomarshal/generator.go382
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go507
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go154
-rw-r--r--tools/go_marshal/gomarshal/util.go387
-rw-r--r--tools/go_marshal/main.go73
-rw-r--r--tools/go_marshal/marshal/BUILD14
-rw-r--r--tools/go_marshal/marshal/marshal.go60
-rw-r--r--tools/go_marshal/test/BUILD31
-rw-r--r--tools/go_marshal/test/benchmark_test.go178
-rw-r--r--tools/go_marshal/test/external/BUILD11
-rw-r--r--tools/go_marshal/test/external/external.go (renamed from test/runtimes/runtimes.go)15
-rw-r--r--tools/go_marshal/test/test.go105
-rw-r--r--tools/go_stateify/defs.bzl65
-rwxr-xr-xtools/make_repository.sh36
-rwxr-xr-xtools/workspace_status.sh3
329 files changed, 9221 insertions, 2327 deletions
diff --git a/.bazelrc b/.bazelrc
index eda884473..379fc8328 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -13,7 +13,7 @@
# limitations under the License.
# Display the current git revision in the info block.
-build --workspace_status_command tools/workspace_status.sh
+build --stamp --workspace_status_command tools/workspace_status.sh
# Enable remote execution so actions are performed on the remote systems.
build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 638942a42..5d46168bc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -83,6 +83,8 @@ Rules:
### Code reviews
+Before sending code reviews, run `bazel test ...` to ensure tests are passing.
+
Code changes are accepted via [pull request][github].
When approved, the change will be submitted by a team member and automatically
@@ -100,6 +102,36 @@ form `b/1234`. These correspond to bugs in our internal bug tracker. Eventually
these bugs will be moved to the GitHub Issues, but until then they can simply be
ignored.
+### Build and test with Docker
+
+`scripts/dev.sh` is a convenient script that builds and installs `runsc` as a
+new Docker runtime for you. The scripts tries to extract the runtime name from
+your local environment and will print it at the end. You can also customize it.
+The script creates one regular runtime and another with debug flags enabled.
+Here are a few examples:
+
+```bash
+# Default case (inside branch my-branch)
+$ scripts/dev.sh
+...
+Runtimes my-branch and my-branch-d (debug enabled) setup.
+Use --runtime=my-branch with your Docker command.
+ docker run --rm --runtime=my-branch --rm hello-world
+
+If you rebuild, use scripts/dev.sh --refresh.
+Logs are in: /tmp/my-branch/logs
+
+# --refresh just updates the runtime binary and doesn't restart docker.
+$ git/my_branch> scripts/dev.sh --refresh
+
+# Using a custom runtime name
+$ git/my_branch> scripts/dev.sh my-runtime
+...
+Runtimes my-runtime and my-runtime-d (debug enabled) setup.
+Use --runtime=my-runtime with your Docker command.
+ docker run --rm --runtime=my-runtime --rm hello-world
+```
+
### The small print
Contributions made by corporations are covered by a different agreement than the
diff --git a/README.md b/README.md
index d102845ac..7ab76d305 100644
--- a/README.md
+++ b/README.md
@@ -48,7 +48,7 @@ Make sure the following dependencies are installed:
* Linux 4.14.77+ ([older linux][old-linux])
* [git][git]
-* [Bazel][bazel] 0.23.0+
+* [Bazel][bazel] 0.28.0+
* [Python][python]
* [Docker version 17.09.0 or greater][docker]
* Gold linker (e.g. `binutils-gold` package on Ubuntu)
diff --git a/WORKSPACE b/WORKSPACE
index c403e774d..77750f9e6 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -3,19 +3,19 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_go",
- sha256 = "313f2c7a23fecc33023563f082f381a32b9b7254f727a7dd2d6380ccc6dfe09b",
+ sha256 = "ae8c36ff6e565f674c7a3692d6a9ea1096e4c1ade497272c2108a810fb39acd2",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/0.19.3/rules_go-0.19.3.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz",
],
)
http_archive(
name = "bazel_gazelle",
- sha256 = "be9296bfd64882e3c08e3283c58fcb461fa6dd3c171764fcc4cf322f60615a9b",
+ sha256 = "7fc87f4170011201b1690326e8c16c5d802836e3a0d617d8f75c3af2b23180c4",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz",
- "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.1/bazel-gazelle-0.18.1.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/0.18.2/bazel-gazelle-0.18.2.tar.gz",
],
)
@@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.12.9",
+ go_version = "1.13",
nogo = "@//:nogo",
)
@@ -44,13 +44,14 @@ http_archive(
)
# Load protobuf dependencies.
-load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
-
-git_repository(
+http_archive(
name = "com_google_protobuf",
- commit = "09745575a923640154bcf307fba8aedff47f240a",
- remote = "https://github.com/protocolbuffers/protobuf",
- shallow_since = "1558721209 -0700",
+ sha256 = "532d2575d8c0992065bb19ec5fba13aa3683499726f6055c11b474f91a00bb0c",
+ strip_prefix = "protobuf-7f520092d9050d96fb4b707ad11a51701af4ce49",
+ urls = [
+ "https://mirror.bazel.build/github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
+ "https://github.com/protocolbuffers/protobuf/archive/7f520092d9050d96fb4b707ad11a51701af4ce49.zip",
+ ],
)
load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
@@ -61,11 +62,11 @@ protobuf_deps()
# See releases at https://releases.bazel.build/bazel-toolchains.html
http_archive(
name = "bazel_toolchains",
- sha256 = "e71eadcfcbdb47b4b740eb48b32ca4226e36aabc425d035a18dd40c2dda808c1",
- strip_prefix = "bazel-toolchains-0.28.4",
+ sha256 = "a019fbd579ce5aed0239de865b2d8281dbb809efd537bf42e0d366783e8dec65",
+ strip_prefix = "bazel-toolchains-0.29.2",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/0.28.4.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/0.29.2.tar.gz",
],
)
@@ -195,7 +196,7 @@ go_repository(
go_repository(
name = "org_golang_x_time",
- commit = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef",
+ commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
)
@@ -221,16 +222,6 @@ go_repository(
# System Call test dependencies.
http_archive(
- name = "com_github_gflags_gflags",
- sha256 = "34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf",
- strip_prefix = "gflags-2.2.2",
- urls = [
- "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.2.tar.gz",
- "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz",
- ],
-)
-
-http_archive(
name = "com_google_absl",
sha256 = "56775f1283a59e6274c28d99981a9717ff4e0b1161e9129fdb2fcf22531d8d93",
strip_prefix = "abseil-cpp-a0d1e098c2f99694fa399b175a7ccf920762030e",
@@ -242,10 +233,10 @@ http_archive(
http_archive(
name = "com_google_googletest",
- sha256 = "db657310d3c5ca2d3f674e3a4b79718d1d39da70604568ee0568ba8e39065ef4",
- strip_prefix = "googletest-31200def0dec8a624c861f919e86e4444e6e6ee7",
+ sha256 = "0a10bea96d8670e5eef948d79d824162b1577bb7889539e49ec786bfc3e48912",
+ strip_prefix = "googletest-565f1b848215b77c3732bca345fe76a0431d8b34",
urls = [
- "https://mirror.bazel.build/github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz",
- "https://github.com/google/googletest/archive/31200def0dec8a624c861f919e86e4444e6e6ee7.tar.gz",
+ "https://mirror.bazel.build/github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
+ "https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
],
)
diff --git a/cloudbuild/go.Dockerfile b/cloudbuild/go.Dockerfile
deleted file mode 100644
index 226442fd2..000000000
--- a/cloudbuild/go.Dockerfile
+++ /dev/null
@@ -1,2 +0,0 @@
-FROM ubuntu
-RUN apt-get -q update && apt-get install -qqy git rsync
diff --git a/cloudbuild/go.yaml b/cloudbuild/go.yaml
deleted file mode 100644
index a38ef71fc..000000000
--- a/cloudbuild/go.yaml
+++ /dev/null
@@ -1,22 +0,0 @@
-steps:
-- name: 'gcr.io/cloud-builders/git'
- args: ['fetch', '--all', '--unshallow']
-- name: 'gcr.io/cloud-builders/bazel'
- args: ['build', ':gopath']
-- name: 'gcr.io/cloud-builders/docker'
- args: ['build', '-t', 'gcr.io/$PROJECT_ID/go-branch', '-f', 'cloudbuild/go.Dockerfile', '.']
-- name: 'gcr.io/$PROJECT_ID/go-branch'
- args: ['tools/go_branch.sh']
-- name: 'gcr.io/cloud-builders/git'
- args: ['checkout', 'go']
-- name: 'gcr.io/cloud-builders/git'
- args: ['clean', '-f']
-- name: 'golang'
- args: ['go', 'build', './...']
-- name: 'gcr.io/cloud-builders/git'
- entrypoint: 'bash'
- args:
- - '-c'
- - 'if [[ "$BRANCH_NAME" == "master" ]]; then git push "${_ORIGIN}" go:go; fi'
-substitutions:
- _ORIGIN: origin
diff --git a/kokoro/build.cfg b/kokoro/build.cfg
index d67af4694..084347dde 100644
--- a/kokoro/build.cfg
+++ b/kokoro/build.cfg
@@ -11,13 +11,12 @@ before_action {
env_vars {
key: "KOKORO_REPO_KEY"
- value: "$KOKORO_ROOT/src/keystore/73898_kokoro-repo-key"
+ value: "73898_kokoro-repo-key"
}
action {
define_artifacts {
- regex: "**/runsc"
- regex: "**/runsc.sha256"
- regex: "**/repo/**"
+ regex: "**/runsc.*"
+ regex: "**/dists/**"
}
}
diff --git a/kokoro/continuous.cfg b/kokoro/continuous.cfg
deleted file mode 100644
index 88694220a..000000000
--- a/kokoro/continuous.cfg
+++ /dev/null
@@ -1,11 +0,0 @@
-# This is a temporary file. It will be removed when new Kokoro jobs exist for
-# all the other presubmits.
-build_file: "repo/scripts/build.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- }
-}
diff --git a/kokoro/go.cfg b/kokoro/go.cfg
index d1577252a..b9c1fcb12 100644
--- a/kokoro/go.cfg
+++ b/kokoro/go.cfg
@@ -1,5 +1,19 @@
build_file: "repo/scripts/go.sh"
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id: 73898
+ keyname: "kokoro-github-access-token"
+ }
+ }
+}
+
+env_vars {
+ key: "KOKORO_GITHUB_ACCESS_TOKEN"
+ value: "73898_kokoro-github-access-token"
+}
+
env_vars {
key: "KOKORO_GO_PUSH"
value: "true"
diff --git a/kokoro/go_test.cfg b/kokoro/go_tests.cfg
index 5eb51041a..5eb51041a 100644
--- a/kokoro/go_test.cfg
+++ b/kokoro/go_tests.cfg
diff --git a/kokoro/presubmit.cfg b/kokoro/presubmit.cfg
deleted file mode 100644
index eb0c78ea4..000000000
--- a/kokoro/presubmit.cfg
+++ /dev/null
@@ -1,11 +0,0 @@
-# This is a temporary file. It will be removed when new Kokoro jobs exist for
-# all the other presubmits.
-build_file: "repo/kokoro/run_tests.sh"
-
-action {
- define_artifacts {
- regex: "**/sponge_log.xml"
- regex: "**/sponge_log.log"
- regex: "**/outputs.zip"
- }
-}
diff --git a/kokoro/release-nightly.cfg b/kokoro/release-nightly.cfg
deleted file mode 100644
index ae134258c..000000000
--- a/kokoro/release-nightly.cfg
+++ /dev/null
@@ -1,10 +0,0 @@
-# This file is a temporary bridge. It will be removed shortly, when Kokoro jobs
-# are configured to point at the new build and release configurations.
-build_file: "repo/kokoro/run_build.sh"
-
-action {
- define_artifacts {
- regex: "**/runsc"
- regex: "**/runsc.sha512"
- }
-}
diff --git a/kokoro/run_build.sh b/kokoro/run_build.sh
deleted file mode 100755
index da6a0c85e..000000000
--- a/kokoro/run_build.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This file is a temporary bridge. We will create multiple independent Kokoro
-# workflows that call each of the build scripts independently.
-KOKORO_BUILD_NIGHTLY=true $(dirname $0)/../scripts/build.sh
diff --git a/kokoro/ubuntu1604/20_bazel.sh b/kokoro/ubuntu1604/20_bazel.sh
index 74b4b8be2..b9a894024 100755
--- a/kokoro/ubuntu1604/20_bazel.sh
+++ b/kokoro/ubuntu1604/20_bazel.sh
@@ -16,9 +16,7 @@
set -xeo pipefail
-# We need to install a specific version of bazel due to a bug with the RBE
-# environment not respecting the dockerPrivileged configuration.
-declare -r BAZEL_VERSION=0.28.1
+declare -r BAZEL_VERSION=0.29.1
# Install bazel dependencies.
apt-get update && apt-get install -y openjdk-8-jdk-headless unzip
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index ba233b93f..f45934466 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -2,9 +2,11 @@
# Linux kernel. It should be used instead of syscall or golang.org/x/sys/unix
# when the host OS may not be Linux.
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "linux",
@@ -44,6 +46,7 @@ go_library(
"sem.go",
"shm.go",
"signal.go",
+ "signalfd.go",
"socket.go",
"splice.go",
"tcp.go",
diff --git a/pkg/abi/linux/signalfd.go b/pkg/abi/linux/signalfd.go
new file mode 100644
index 000000000..85fad9956
--- /dev/null
+++ b/pkg/abi/linux/signalfd.go
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+const (
+ // SFD_NONBLOCK is a signalfd(2) flag.
+ SFD_NONBLOCK = 00004000
+
+ // SFD_CLOEXEC is a signalfd(2) flag.
+ SFD_CLOEXEC = 02000000
+)
+
+// SignalfdSiginfo is the siginfo encoding for signalfds.
+type SignalfdSiginfo struct {
+ Signo uint32
+ Errno int32
+ Code int32
+ PID uint32
+ UID uint32
+ FD int32
+ TID uint32
+ Band uint32
+ Overrun uint32
+ TrapNo uint32
+ Status int32
+ Int int32
+ Ptr uint64
+ UTime uint64
+ STime uint64
+ Addr uint64
+ AddrLSB uint16
+ _ [48]uint8
+}
diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD
index 39d253b98..6bc486b62 100644
--- a/pkg/amutex/BUILD
+++ b/pkg/amutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 47ab65346..5f59866fa 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/binary/BUILD b/pkg/binary/BUILD
index 09d6c2c1f..543fb54bf 100644
--- a/pkg/binary/BUILD
+++ b/pkg/binary/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/bits/BUILD b/pkg/bits/BUILD
index 0c2dde4f8..51967b811 100644
--- a/pkg/bits/BUILD
+++ b/pkg/bits/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/bpf/BUILD b/pkg/bpf/BUILD
index b692aa3b1..8d31e068c 100644
--- a/pkg/bpf/BUILD
+++ b/pkg/bpf/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "bpf",
diff --git a/pkg/compressio/BUILD b/pkg/compressio/BUILD
index cdec96df1..a0b21d4bd 100644
--- a/pkg/compressio/BUILD
+++ b/pkg/compressio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/cpuid/BUILD b/pkg/cpuid/BUILD
index 830e19e07..32422f9e2 100644
--- a/pkg/cpuid/BUILD
+++ b/pkg/cpuid/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "cpuid",
diff --git a/pkg/eventchannel/BUILD b/pkg/eventchannel/BUILD
index 9961baaa9..71f2abc83 100644
--- a/pkg/eventchannel/BUILD
+++ b/pkg/eventchannel/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index 785c685a0..afa8f7659 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/fdchannel/BUILD b/pkg/fdchannel/BUILD
index e54e7371c..56495cbd9 100644
--- a/pkg/fdchannel/BUILD
+++ b/pkg/fdchannel/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/flipcall/BUILD b/pkg/flipcall/BUILD
index bd1d614b6..5643d5f26 100644
--- a/pkg/flipcall/BUILD
+++ b/pkg/flipcall/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -18,6 +19,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/log",
"//pkg/memutil",
+ "//third_party/gvsync",
],
)
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index d59159912..8390915a2 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -82,6 +82,7 @@ func (ep *Endpoint) ctrlWaitFirst() error {
*ep.dataLen() = w.Len()
// Return control to the client.
+ raceBecomeInactive()
if err := ep.futexSwitchToPeer(); err != nil {
return err
}
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 991018684..386cee42c 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -180,7 +180,11 @@ const (
// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(),
// ep.SendRecv(), and ep.SendLast() have never been called.
func (ep *Endpoint) Connect() error {
- return ep.ctrlConnect()
+ err := ep.ctrlConnect()
+ if err == nil {
+ raceBecomeActive()
+ }
+ return err
}
// RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then
@@ -192,6 +196,7 @@ func (ep *Endpoint) RecvFirst() (uint32, error) {
if err := ep.ctrlWaitFirst(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -218,9 +223,11 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) {
// after ep.ctrlRoundTrip(), so if the peer is mutating it concurrently then
// they can only shoot themselves in the foot.
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlRoundTrip(); err != nil {
return 0, err
}
+ raceBecomeActive()
recvDataLen := atomic.LoadUint32(ep.dataLen())
if recvDataLen > ep.dataCap {
return 0, fmt.Errorf("received packet with invalid datagram length %d (maximum %d)", recvDataLen, ep.dataCap)
@@ -240,6 +247,7 @@ func (ep *Endpoint) SendLast(dataLen uint32) error {
panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap))
}
*ep.dataLen() = dataLen
+ raceBecomeInactive()
if err := ep.ctrlWakeLast(); err != nil {
return err
}
diff --git a/pkg/flipcall/flipcall_test.go b/pkg/flipcall/flipcall_test.go
index 435e4eeae..168a487ec 100644
--- a/pkg/flipcall/flipcall_test.go
+++ b/pkg/flipcall/flipcall_test.go
@@ -62,6 +62,9 @@ func (c *testConnection) destroy() {
}
func testSendRecv(t *testing.T, c *testConnection) {
+ // This shared variable is used to confirm that synchronization between
+ // flipcall endpoints is visible to the Go race detector.
+ state := 0
var serverRun sync.WaitGroup
serverRun.Add(1)
go func() {
@@ -71,11 +74,19 @@ func testSendRecv(t *testing.T, c *testConnection) {
t.Errorf("server Endpoint.RecvFirst() failed: %v", err)
return
}
+ state++
+ if state != 2 {
+ t.Errorf("shared state counter: got %d, wanted 2", state)
+ }
t.Logf("server Endpoint got packet 1, sending packet 2 and waiting for packet 3")
if _, err := c.serverEP.SendRecv(0); err != nil {
t.Errorf("server Endpoint.SendRecv() failed: %v", err)
return
}
+ state++
+ if state != 4 {
+ t.Errorf("shared state counter: got %d, wanted 4", state)
+ }
t.Logf("server Endpoint got packet 3")
}()
defer func() {
@@ -89,10 +100,18 @@ func testSendRecv(t *testing.T, c *testConnection) {
if err := c.clientEP.Connect(); err != nil {
t.Fatalf("client Endpoint.Connect() failed: %v", err)
}
+ state++
+ if state != 1 {
+ t.Errorf("shared state counter: got %d, wanted 1", state)
+ }
t.Logf("client Endpoint sending packet 1 and waiting for packet 2")
if _, err := c.clientEP.SendRecv(0); err != nil {
t.Fatalf("client Endpoint.SendRecv() failed: %v", err)
}
+ state++
+ if state != 3 {
+ t.Errorf("shared state counter: got %d, wanted 3", state)
+ }
t.Logf("client Endpoint got packet 2, sending packet 3")
if err := c.clientEP.SendLast(0); err != nil {
t.Fatalf("client Endpoint.SendLast() failed: %v", err)
diff --git a/pkg/flipcall/flipcall_unsafe.go b/pkg/flipcall/flipcall_unsafe.go
index 73e6eef29..a37952637 100644
--- a/pkg/flipcall/flipcall_unsafe.go
+++ b/pkg/flipcall/flipcall_unsafe.go
@@ -17,6 +17,8 @@ package flipcall
import (
"reflect"
"unsafe"
+
+ "gvisor.dev/gvisor/third_party/gvsync"
)
// Packets consist of a 16-byte header followed by an arbitrarily-sized
@@ -67,3 +69,19 @@ func (ep *Endpoint) Data() []byte {
bsReflect.Cap = int(ep.dataCap)
return bs
}
+
+// ioSync is a dummy variable used to indicate synchronization to the Go race
+// detector. Compare syscall.ioSync.
+var ioSync int64
+
+func raceBecomeActive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceAcquire((unsafe.Pointer)(&ioSync))
+ }
+}
+
+func raceBecomeInactive() {
+ if gvsync.RaceEnabled {
+ gvsync.RaceReleaseMerge((unsafe.Pointer)(&ioSync))
+ }
+}
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index 11716af81..0c5f50397 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
diff --git a/pkg/gate/BUILD b/pkg/gate/BUILD
index e6a8dbd02..4b9321711 100644
--- a/pkg/gate/BUILD
+++ b/pkg/gate/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/ilist/BUILD b/pkg/ilist/BUILD
index 8f3defa25..34d2673ef 100644
--- a/pkg/ilist/BUILD
+++ b/pkg/ilist/BUILD
@@ -1,5 +1,6 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
package(licenses = ["notice"])
diff --git a/pkg/linewriter/BUILD b/pkg/linewriter/BUILD
index c8e923a74..a5d980d14 100644
--- a/pkg/linewriter/BUILD
+++ b/pkg/linewriter/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/log/BUILD b/pkg/log/BUILD
index 12615240c..fc5f5779b 100644
--- a/pkg/log/BUILD
+++ b/pkg/log/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index 3b8a691f4..dd6ca6d39 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -1,5 +1,7 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -21,6 +23,12 @@ proto_library(
visibility = ["//:sandbox"],
)
+cc_proto_library(
+ name = "metric_cc_proto",
+ visibility = ["//:sandbox"],
+ deps = [":metric_proto"],
+)
+
go_proto_library(
name = "metric_go_proto",
importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto",
diff --git a/pkg/p9/BUILD b/pkg/p9/BUILD
index c6737bf97..f32244c69 100644
--- a/pkg/p9/BUILD
+++ b/pkg/p9/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:public"],
@@ -19,11 +20,14 @@ go_library(
"pool.go",
"server.go",
"transport.go",
+ "transport_flipcall.go",
"version.go",
],
importpath = "gvisor.dev/gvisor/pkg/p9",
deps = [
"//pkg/fd",
+ "//pkg/fdchannel",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/unet",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/pkg/p9/client.go b/pkg/p9/client.go
index 7dc20aeef..2412aa5e1 100644
--- a/pkg/p9/client.go
+++ b/pkg/p9/client.go
@@ -20,6 +20,8 @@ import (
"sync"
"syscall"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -77,6 +79,47 @@ type Client struct {
// fidPool is the collection of available fids.
fidPool pool
+ // messageSize is the maximum total size of a message.
+ messageSize uint32
+
+ // payloadSize is the maximum payload size of a read or write.
+ //
+ // For large reads and writes this means that the read or write is
+ // broken up into buffer-size/payloadSize requests.
+ payloadSize uint32
+
+ // version is the agreed upon version X of 9P2000.L.Google.X.
+ // version 0 implies 9P2000.L.
+ version uint32
+
+ // closedWg is marked as done when the Client.watch() goroutine, which is
+ // responsible for closing channels and the socket fd, returns.
+ closedWg sync.WaitGroup
+
+ // sendRecv is the transport function.
+ //
+ // This is determined dynamically based on whether or not the server
+ // supports flipcall channels (preferred as it is faster and more
+ // efficient, and does not require tags).
+ sendRecv func(message, message) error
+
+ // -- below corresponds to sendRecvChannel --
+
+ // channelsMu protects channels.
+ channelsMu sync.Mutex
+
+ // channelsWg counts the number of channels for which channel.active ==
+ // true.
+ channelsWg sync.WaitGroup
+
+ // channels is the set of all initialized channels.
+ channels []*channel
+
+ // availableChannels is a FIFO of inactive channels.
+ availableChannels []*channel
+
+ // -- below corresponds to sendRecvLegacy --
+
// pending is the set of pending messages.
pending map[Tag]*response
pendingMu sync.Mutex
@@ -89,25 +132,12 @@ type Client struct {
// Whoever writes to this channel is permitted to call recv. When
// finished calling recv, this channel should be emptied.
recvr chan bool
-
- // messageSize is the maximum total size of a message.
- messageSize uint32
-
- // payloadSize is the maximum payload size of a read or write
- // request. For large reads and writes this means that the
- // read or write is broken up into buffer-size/payloadSize
- // requests.
- payloadSize uint32
-
- // version is the agreed upon version X of 9P2000.L.Google.X.
- // version 0 implies 9P2000.L.
- version uint32
}
// NewClient creates a new client. It performs a Tversion exchange with
// the server to assert that messageSize is ok to use.
//
-// You should not use the same socket for multiple clients.
+// If NewClient succeeds, ownership of socket is transferred to the new Client.
func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
// Need at least one byte of payload.
if messageSize <= msgRegistry.largestFixedSize {
@@ -138,8 +168,15 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
return nil, ErrBadVersionString
}
for {
+ // Always exchange the version using the legacy version of the
+ // protocol. If the protocol supports flipcall, then we switch
+ // our sendRecv function to use that functionality. Otherwise,
+ // we stick to sendRecvLegacy.
rversion := Rversion{}
- err := c.sendRecv(&Tversion{Version: versionString(requested), MSize: messageSize}, &rversion)
+ err := c.sendRecvLegacy(&Tversion{
+ Version: versionString(requested),
+ MSize: messageSize,
+ }, &rversion)
// The server told us to try again with a lower version.
if err == syscall.EAGAIN {
@@ -165,9 +202,155 @@ func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client
c.version = version
break
}
+
+ // Can we switch to use the more advanced channels and create
+ // independent channels for communication? Prefer it if possible.
+ if versionSupportsFlipcall(c.version) {
+ // Attempt to initialize IPC-based communication.
+ for i := 0; i < channelsPerClient; i++ {
+ if err := c.openChannel(i); err != nil {
+ log.Warningf("error opening flipcall channel: %v", err)
+ break // Stop.
+ }
+ }
+ if len(c.channels) >= 1 {
+ // At least one channel created.
+ c.sendRecv = c.sendRecvChannel
+ } else {
+ // Channel setup failed; fallback.
+ c.sendRecv = c.sendRecvLegacy
+ }
+ } else {
+ // No channels available: use the legacy mechanism.
+ c.sendRecv = c.sendRecvLegacy
+ }
+
+ // Ensure that the socket and channels are closed when the socket is shut
+ // down.
+ c.closedWg.Add(1)
+ go c.watch(socket) // S/R-SAFE: not relevant.
+
return c, nil
}
+// watch watches the given socket and releases resources on hangup events.
+//
+// This is intended to be called as a goroutine.
+func (c *Client) watch(socket *unet.Socket) {
+ defer c.closedWg.Done()
+
+ events := []unix.PollFd{
+ unix.PollFd{
+ Fd: int32(socket.FD()),
+ Events: unix.POLLHUP | unix.POLLRDHUP,
+ },
+ }
+
+ // Wait for a shutdown event.
+ for {
+ n, err := unix.Ppoll(events, nil, nil)
+ if err == syscall.EINTR || err == syscall.EAGAIN {
+ continue
+ }
+ if err != nil {
+ log.Warningf("p9.Client.watch(): %v", err)
+ break
+ }
+ if n != 1 {
+ log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
+ }
+ break
+ }
+
+ // Set availableChannels to nil so that future calls to c.sendRecvChannel()
+ // don't attempt to activate a channel, and concurrent calls to
+ // c.sendRecvChannel() don't mark released channels as available.
+ c.channelsMu.Lock()
+ c.availableChannels = nil
+
+ // Shut down all active channels.
+ for _, ch := range c.channels {
+ if ch.active {
+ log.Debugf("shutting down active channel@%p...", ch)
+ ch.Shutdown()
+ }
+ }
+ c.channelsMu.Unlock()
+
+ // Wait for active channels to become inactive.
+ c.channelsWg.Wait()
+
+ // Close all channels.
+ c.channelsMu.Lock()
+ for _, ch := range c.channels {
+ ch.Close()
+ }
+ c.channelsMu.Unlock()
+
+ // Close the main socket.
+ c.socket.Close()
+}
+
+// openChannel attempts to open a client channel.
+//
+// Note that this function returns naked errors which should not be propagated
+// directly to a caller. It is expected that the errors will be logged and a
+// fallback path will be used instead.
+func (c *Client) openChannel(id int) error {
+ var (
+ rchannel0 Rchannel
+ rchannel1 Rchannel
+ res = new(channel)
+ )
+
+ // Open the data channel.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 0,
+ }, &rchannel0); err != nil {
+ return fmt.Errorf("error handling Tchannel message: %v", err)
+ }
+ if rchannel0.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on primary channel")
+ }
+
+ // We don't need to hold this.
+ defer rchannel0.FilePayload().Close()
+
+ // Open the channel for file descriptors.
+ if err := c.sendRecvLegacy(&Tchannel{
+ ID: uint32(id),
+ Control: 1,
+ }, &rchannel1); err != nil {
+ return err
+ }
+ if rchannel1.FilePayload() == nil {
+ return fmt.Errorf("missing file descriptor on file descriptor channel")
+ }
+
+ // Construct the endpoints.
+ res.desc = flipcall.PacketWindowDescriptor{
+ FD: rchannel0.FilePayload().FD(),
+ Offset: int64(rchannel0.Offset),
+ Length: int(rchannel0.Length),
+ }
+ if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
+ rchannel1.FilePayload().Close()
+ return err
+ }
+
+ // The fds channel owns the control payload, and it will be closed when
+ // the channel object is closed.
+ res.fds.Init(rchannel1.FilePayload().Release())
+
+ // Save the channel.
+ c.channelsMu.Lock()
+ defer c.channelsMu.Unlock()
+ c.channels = append(c.channels, res)
+ c.availableChannels = append(c.availableChannels, res)
+ return nil
+}
+
// handleOne handles a single incoming message.
//
// This should only be called with the token from recvr. Note that the received
@@ -247,10 +430,10 @@ func (c *Client) waitAndRecv(done chan error) error {
}
}
-// sendRecv performs a roundtrip message exchange.
+// sendRecvLegacy performs a roundtrip message exchange.
//
// This is called by internal functions.
-func (c *Client) sendRecv(t message, r message) error {
+func (c *Client) sendRecvLegacy(t message, r message) error {
tag, ok := c.tagPool.Get()
if !ok {
return ErrOutOfTags
@@ -296,12 +479,62 @@ func (c *Client) sendRecv(t message, r message) error {
return nil
}
+// sendRecvChannel uses channels to send a message.
+func (c *Client) sendRecvChannel(t message, r message) error {
+ // Acquire an available channel.
+ c.channelsMu.Lock()
+ if len(c.availableChannels) == 0 {
+ c.channelsMu.Unlock()
+ return c.sendRecvLegacy(t, r)
+ }
+ idx := len(c.availableChannels) - 1
+ ch := c.availableChannels[idx]
+ c.availableChannels = c.availableChannels[:idx]
+ ch.active = true
+ c.channelsWg.Add(1)
+ c.channelsMu.Unlock()
+
+ // Ensure that it's connected.
+ if !ch.connected {
+ ch.connected = true
+ if err := ch.data.Connect(); err != nil {
+ // The channel is unusable, so don't return it to
+ // c.availableChannels. However, we still have to mark it as
+ // inactive so c.watch() doesn't wait for it.
+ c.channelsMu.Lock()
+ ch.active = false
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+ return err
+ }
+ }
+
+ // Send the message.
+ err := ch.sendRecv(c, t, r)
+
+ // Release the channel.
+ c.channelsMu.Lock()
+ ch.active = false
+ // If c.availableChannels is nil, c.watch() has fired and we should not
+ // mark this channel as available.
+ if c.availableChannels != nil {
+ c.availableChannels = append(c.availableChannels, ch)
+ }
+ c.channelsMu.Unlock()
+ c.channelsWg.Done()
+
+ return err
+}
+
// Version returns the negotiated 9P2000.L.Google version number.
func (c *Client) Version() uint32 {
return c.version
}
-// Close closes the underlying socket.
-func (c *Client) Close() error {
- return c.socket.Close()
+// Close closes the underlying socket and channels.
+func (c *Client) Close() {
+ // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
+ // been called (by c.watch()).
+ c.socket.Shutdown()
+ c.closedWg.Wait()
}
diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go
index 87b2dd61e..29a0afadf 100644
--- a/pkg/p9/client_test.go
+++ b/pkg/p9/client_test.go
@@ -35,23 +35,23 @@ func TestVersion(t *testing.T) {
go s.Handle(serverSocket)
// NewClient does a Tversion exchange, so this is our test for success.
- c, err := NewClient(clientSocket, 1024*1024 /* 1M message size */, HighestVersionString())
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
if err != nil {
t.Fatalf("got %v, expected nil", err)
}
// Check a bogus version string.
- if err := c.sendRecv(&Tversion{Version: "notokay", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "notokay", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a bogus version number.
- if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: 1024 * 1024}, &Rversion{}); err != syscall.EINVAL {
+ if err := c.sendRecv(&Tversion{Version: "9P1000.L", MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EINVAL {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
// Check a too high version number.
- if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: 1024 * 1024}, &Rversion{}); err != syscall.EAGAIN {
+ if err := c.sendRecv(&Tversion{Version: versionString(highestSupportedVersion + 1), MSize: DefaultMessageSize}, &Rversion{}); err != syscall.EAGAIN {
t.Errorf("got %v expected %v", err, syscall.EAGAIN)
}
@@ -60,3 +60,45 @@ func TestVersion(t *testing.T) {
t.Errorf("got %v expected %v", err, syscall.EINVAL)
}
}
+
+func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) {
+ // See above.
+ serverSocket, clientSocket, err := unet.SocketPair(false)
+ if err != nil {
+ b.Fatalf("socketpair got err %v expected nil", err)
+ }
+ defer clientSocket.Close()
+
+ // See above.
+ s := NewServer(nil)
+ go s.Handle(serverSocket)
+
+ // See above.
+ c, err := NewClient(clientSocket, DefaultMessageSize, HighestVersionString())
+ if err != nil {
+ b.Fatalf("got %v, expected nil", err)
+ }
+
+ // Initialize messages.
+ sendRecv := fn(c)
+ tversion := &Tversion{
+ Version: versionString(highestSupportedVersion),
+ MSize: DefaultMessageSize,
+ }
+ rversion := new(Rversion)
+
+ // Run in a loop.
+ for i := 0; i < b.N; i++ {
+ if err := sendRecv(tversion, rversion); err != nil {
+ b.Fatalf("got unexpected err: %v", err)
+ }
+ }
+}
+
+func BenchmarkSendRecvLegacy(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvLegacy })
+}
+
+func BenchmarkSendRecvChannel(b *testing.B) {
+ benchmarkSendRecv(b, func(c *Client) func(message, message) error { return c.sendRecvChannel })
+}
diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go
index 999b4f684..ba9a55d6d 100644
--- a/pkg/p9/handlers.go
+++ b/pkg/p9/handlers.go
@@ -305,7 +305,9 @@ func (t *Tlopen) handle(cs *connState) message {
ref.opened = true
ref.openFlags = t.Flags
- return &Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}
+ rlopen := &Rlopen{QID: qid, IoUnit: ioUnit}
+ rlopen.SetFilePayload(osFile)
+ return rlopen
}
func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
@@ -364,7 +366,9 @@ func (t *Tlcreate) do(cs *connState, uid UID) (*Rlcreate, error) {
// Replace the FID reference.
cs.InsertFID(t.FID, newRef)
- return &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit, File: osFile}}, nil
+ rlcreate := &Rlcreate{Rlopen: Rlopen{QID: qid, IoUnit: ioUnit}}
+ rlcreate.SetFilePayload(osFile)
+ return rlcreate, nil
}
// handle implements handler.handle.
@@ -1287,5 +1291,48 @@ func (t *Tlconnect) handle(cs *connState) message {
return newErr(err)
}
- return &Rlconnect{File: osFile}
+ rlconnect := &Rlconnect{}
+ rlconnect.SetFilePayload(osFile)
+ return rlconnect
+}
+
+// handle implements handler.handle.
+func (t *Tchannel) handle(cs *connState) message {
+ // Ensure that channels are enabled.
+ if err := cs.initializeChannels(); err != nil {
+ return newErr(err)
+ }
+
+ // Lookup the given channel.
+ ch := cs.lookupChannel(t.ID)
+ if ch == nil {
+ return newErr(syscall.ENOSYS)
+ }
+
+ // Return the payload. Note that we need to duplicate the file
+ // descriptor for the channel allocator, because sending is a
+ // destructive operation between sendRecvLegacy (and now the newer
+ // channel send operations). Same goes for the client FD.
+ rchannel := &Rchannel{
+ Offset: uint64(ch.desc.Offset),
+ Length: uint64(ch.desc.Length),
+ }
+ switch t.Control {
+ case 0:
+ // Open the main data channel.
+ mfd, err := syscall.Dup(int(cs.channelAlloc.FD()))
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(mfd))
+ case 1:
+ cfd, err := syscall.Dup(ch.client.FD())
+ if err != nil {
+ return newErr(err)
+ }
+ rchannel.SetFilePayload(fd.New(cfd))
+ default:
+ return newErr(syscall.EINVAL)
+ }
+ return rchannel
}
diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go
index fd9eb1c5d..ffdd7e8c6 100644
--- a/pkg/p9/messages.go
+++ b/pkg/p9/messages.go
@@ -64,6 +64,21 @@ type filer interface {
SetFilePayload(*fd.FD)
}
+// filePayload embeds a File object.
+type filePayload struct {
+ File *fd.FD
+}
+
+// FilePayload returns the file payload.
+func (f *filePayload) FilePayload() *fd.FD {
+ return f.File
+}
+
+// SetFilePayload sets the received file.
+func (f *filePayload) SetFilePayload(file *fd.FD) {
+ f.File = file
+}
+
// Tversion is a version request.
type Tversion struct {
// MSize is the message size to use.
@@ -524,10 +539,7 @@ type Rlopen struct {
// IoUnit is the recommended I/O unit.
IoUnit uint32
- // File may be attached via the socket.
- //
- // This is an extension specific to this package.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -547,16 +559,6 @@ func (*Rlopen) Type() MsgType {
return MsgRlopen
}
-// FilePayload returns the file payload.
-func (r *Rlopen) FilePayload() *fd.FD {
- return r.File
-}
-
-// SetFilePayload sets the received file.
-func (r *Rlopen) SetFilePayload(file *fd.FD) {
- r.File = file
-}
-
// String implements fmt.Stringer.
func (r *Rlopen) String() string {
return fmt.Sprintf("Rlopen{QID: %s, IoUnit: %d, File: %v}", r.QID, r.IoUnit, r.File)
@@ -2171,8 +2173,7 @@ func (t *Tlconnect) String() string {
// Rlconnect is a connect response.
type Rlconnect struct {
- // File is a host socket.
- File *fd.FD
+ filePayload
}
// Decode implements encoder.Decode.
@@ -2186,19 +2187,71 @@ func (*Rlconnect) Type() MsgType {
return MsgRlconnect
}
-// FilePayload returns the file payload.
-func (r *Rlconnect) FilePayload() *fd.FD {
- return r.File
+// String implements fmt.Stringer.
+func (r *Rlconnect) String() string {
+ return fmt.Sprintf("Rlconnect{File: %v}", r.File)
}
-// SetFilePayload sets the received file.
-func (r *Rlconnect) SetFilePayload(file *fd.FD) {
- r.File = file
+// Tchannel creates a new channel.
+type Tchannel struct {
+ // ID is the channel ID.
+ ID uint32
+
+ // Control is 0 if the Rchannel response should provide the flipcall
+ // component of the channel, and 1 if the Rchannel response should
+ // provide the fdchannel component of the channel.
+ Control uint32
+}
+
+// Decode implements encoder.Decode.
+func (t *Tchannel) Decode(b *buffer) {
+ t.ID = b.Read32()
+ t.Control = b.Read32()
+}
+
+// Encode implements encoder.Encode.
+func (t *Tchannel) Encode(b *buffer) {
+ b.Write32(t.ID)
+ b.Write32(t.Control)
+}
+
+// Type implements message.Type.
+func (*Tchannel) Type() MsgType {
+ return MsgTchannel
}
// String implements fmt.Stringer.
-func (r *Rlconnect) String() string {
- return fmt.Sprintf("Rlconnect{File: %v}", r.File)
+func (t *Tchannel) String() string {
+ return fmt.Sprintf("Tchannel{ID: %d, Control: %d}", t.ID, t.Control)
+}
+
+// Rchannel is the channel response.
+type Rchannel struct {
+ Offset uint64
+ Length uint64
+ filePayload
+}
+
+// Decode implements encoder.Decode.
+func (r *Rchannel) Decode(b *buffer) {
+ r.Offset = b.Read64()
+ r.Length = b.Read64()
+}
+
+// Encode implements encoder.Encode.
+func (r *Rchannel) Encode(b *buffer) {
+ b.Write64(r.Offset)
+ b.Write64(r.Length)
+}
+
+// Type implements message.Type.
+func (*Rchannel) Type() MsgType {
+ return MsgRchannel
+}
+
+// String implements fmt.Stringer.
+func (r *Rchannel) String() string {
+ return fmt.Sprintf("Rchannel{Offset: %d, Length: %d}", r.Offset, r.Length)
}
const maxCacheSize = 3
@@ -2356,4 +2409,6 @@ func init() {
msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} })
msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} })
msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} })
+ msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} })
+ msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} })
}
diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go
index e12831dbd..25530adca 100644
--- a/pkg/p9/p9.go
+++ b/pkg/p9/p9.go
@@ -378,6 +378,8 @@ const (
MsgRlconnect = 137
MsgTallocate = 138
MsgRallocate = 139
+ MsgTchannel = 250
+ MsgRchannel = 251
)
// QIDType represents the file type for QIDs.
diff --git a/pkg/p9/p9test/BUILD b/pkg/p9/p9test/BUILD
index 6e939a49a..28707c0ca 100644
--- a/pkg/p9/p9test/BUILD
+++ b/pkg/p9/p9test/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
package(licenses = ["notice"])
@@ -77,7 +77,7 @@ go_library(
go_test(
name = "client_test",
- size = "small",
+ size = "medium",
srcs = ["client_test.go"],
embed = [":p9test"],
deps = [
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/p9/p9test/p9test.go b/pkg/p9/p9test/p9test.go
index 95846e5f7..4d3271b37 100644
--- a/pkg/p9/p9test/p9test.go
+++ b/pkg/p9/p9test/p9test.go
@@ -279,7 +279,7 @@ func (h *Harness) NewSocket() Generator {
// Finish completes all checks and shuts down the server.
func (h *Harness) Finish() {
- h.clientSocket.Close()
+ h.clientSocket.Shutdown()
h.wg.Wait()
h.mockCtrl.Finish()
}
@@ -315,7 +315,7 @@ func NewHarness(t *testing.T) (*Harness, *p9.Client) {
}()
// Create the client.
- client, err := p9.NewClient(clientSocket, 1024, p9.HighestVersionString())
+ client, err := p9.NewClient(clientSocket, p9.DefaultMessageSize, p9.HighestVersionString())
if err != nil {
serverSocket.Close()
clientSocket.Close()
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index b294efbb0..69c886a5d 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -21,6 +21,9 @@ import (
"sync/atomic"
"syscall"
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -45,7 +48,6 @@ type Server struct {
}
// NewServer returns a new server.
-//
func NewServer(attacher Attacher) *Server {
return &Server{
attacher: attacher,
@@ -85,6 +87,8 @@ type connState struct {
// version 0 implies 9P2000.L.
version uint32
+ // -- below relates to the legacy handler --
+
// recvOkay indicates that a receive may start.
recvOkay chan bool
@@ -93,6 +97,20 @@ type connState struct {
// sendDone is signalled when a send is finished.
sendDone chan error
+
+ // -- below relates to the flipcall handler --
+
+ // channelMu protects below.
+ channelMu sync.Mutex
+
+ // channelWg represents active workers.
+ channelWg sync.WaitGroup
+
+ // channelAlloc allocates channel memory.
+ channelAlloc *flipcall.PacketWindowAllocator
+
+ // channels are the set of initialized channels.
+ channels []*channel
}
// fidRef wraps a node and tracks references.
@@ -386,6 +404,99 @@ func (cs *connState) WaitTag(t Tag) {
<-ch
}
+// initializeChannels initializes all channels.
+//
+// This is a no-op if channels are already initialized.
+func (cs *connState) initializeChannels() (err error) {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+
+ // Initialize our channel allocator.
+ if cs.channelAlloc == nil {
+ alloc, err := flipcall.NewPacketWindowAllocator()
+ if err != nil {
+ return err
+ }
+ cs.channelAlloc = alloc
+ }
+
+ // Create all the channels.
+ for len(cs.channels) < channelsPerClient {
+ res := &channel{
+ done: make(chan struct{}),
+ }
+
+ res.desc, err = cs.channelAlloc.Allocate(channelSize)
+ if err != nil {
+ return err
+ }
+ if err := res.data.Init(flipcall.ServerSide, res.desc); err != nil {
+ return err
+ }
+
+ socks, err := fdchannel.NewConnectedSockets()
+ if err != nil {
+ res.data.Destroy() // Cleanup.
+ return err
+ }
+ res.fds.Init(socks[0])
+ res.client = fd.New(socks[1])
+
+ cs.channels = append(cs.channels, res)
+
+ // Start servicing the channel.
+ //
+ // When we call stop, we will close all the channels and these
+ // routines should finish. We need the wait group to ensure
+ // that active handlers are actually finished before cleanup.
+ cs.channelWg.Add(1)
+ go func() { // S/R-SAFE: Server side.
+ defer cs.channelWg.Done()
+ res.service(cs)
+ }()
+ }
+
+ return nil
+}
+
+// lookupChannel looks up the channel with given id.
+//
+// The function returns nil if no such channel is available.
+func (cs *connState) lookupChannel(id uint32) *channel {
+ cs.channelMu.Lock()
+ defer cs.channelMu.Unlock()
+ if id >= uint32(len(cs.channels)) {
+ return nil
+ }
+ return cs.channels[id]
+}
+
+// handle handles a single message.
+func (cs *connState) handle(m message) (r message) {
+ defer func() {
+ if r == nil {
+ // Don't allow a panic to propagate.
+ recover()
+
+ // Include a useful log message.
+ log.Warningf("panic in handler: %s", debug.Stack())
+
+ // Wrap in an EFAULT error; we don't really have a
+ // better way to describe this kind of error. It will
+ // usually manifest as a result of the test framework.
+ r = newErr(syscall.EFAULT)
+ }
+ }()
+ if handler, ok := m.(handler); ok {
+ // Call the message handler.
+ r = handler.handle(cs)
+ } else {
+ // Produce an ENOSYS error.
+ r = newErr(syscall.ENOSYS)
+ }
+ return
+}
+
// handleRequest handles a single request.
//
// The recvDone channel is signaled when recv is done (with a error if
@@ -428,41 +539,20 @@ func (cs *connState) handleRequest() {
}
// Handle the message.
- var r message // r is the response.
- defer func() {
- if r == nil {
- // Don't allow a panic to propagate.
- recover()
+ r := cs.handle(m)
- // Include a useful log message.
- log.Warningf("panic in handler: %s", debug.Stack())
+ // Clear the tag before sending. That's because as soon as this hits
+ // the wire, the client can legally send the same tag.
+ cs.ClearTag(tag)
- // Wrap in an EFAULT error; we don't really have a
- // better way to describe this kind of error. It will
- // usually manifest as a result of the test framework.
- r = newErr(syscall.EFAULT)
- }
+ // Send back the result.
+ cs.sendMu.Lock()
+ err = send(cs.conn, tag, r)
+ cs.sendMu.Unlock()
+ cs.sendDone <- err
- // Clear the tag before sending. That's because as soon as this
- // hits the wire, the client can legally send another message
- // with the same tag.
- cs.ClearTag(tag)
-
- // Send back the result.
- cs.sendMu.Lock()
- err = send(cs.conn, tag, r)
- cs.sendMu.Unlock()
- cs.sendDone <- err
- }()
- if handler, ok := m.(handler); ok {
- // Call the message handler.
- r = handler.handle(cs)
- } else {
- // Produce an ENOSYS error.
- r = newErr(syscall.ENOSYS)
- }
+ // Return the message to the cache.
msgRegistry.put(m)
- m = nil // 'm' should not be touched after this point.
}
func (cs *connState) handleRequests() {
@@ -477,7 +567,27 @@ func (cs *connState) stop() {
close(cs.recvDone)
close(cs.sendDone)
- for _, fidRef := range cs.fids {
+ // Free the channels.
+ cs.channelMu.Lock()
+ for _, ch := range cs.channels {
+ ch.Shutdown()
+ }
+ cs.channelWg.Wait()
+ for _, ch := range cs.channels {
+ ch.Close()
+ }
+ cs.channels = nil // Clear.
+ cs.channelMu.Unlock()
+
+ // Free the channel memory.
+ if cs.channelAlloc != nil {
+ cs.channelAlloc.Destroy()
+ }
+
+ // Close all remaining fids.
+ for fid, fidRef := range cs.fids {
+ delete(cs.fids, fid)
+
// Drop final reference in the FID table. Note this should
// always close the file, since we've ensured that there are no
// handlers running via the wait for Pending => 0 below.
@@ -510,7 +620,7 @@ func (cs *connState) service() error {
for i := 0; i < pending; i++ {
<-cs.sendDone
}
- return err
+ return nil
}
// This handler is now pending.
diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go
index 5648df589..6e8b4bbcd 100644
--- a/pkg/p9/transport.go
+++ b/pkg/p9/transport.go
@@ -54,7 +54,10 @@ const (
headerLength uint32 = 7
// maximumLength is the largest possible message.
- maximumLength uint32 = 4 * 1024 * 1024
+ maximumLength uint32 = 1 << 20
+
+ // DefaultMessageSize is a sensible default.
+ DefaultMessageSize uint32 = 64 << 10
// initialBufferLength is the initial data buffer we allocate.
initialBufferLength uint32 = 64
diff --git a/pkg/p9/transport_flipcall.go b/pkg/p9/transport_flipcall.go
new file mode 100644
index 000000000..7cdf4ecc3
--- /dev/null
+++ b/pkg/p9/transport_flipcall.go
@@ -0,0 +1,263 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package p9
+
+import (
+ "runtime"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/fd"
+ "gvisor.dev/gvisor/pkg/fdchannel"
+ "gvisor.dev/gvisor/pkg/flipcall"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+// channelsPerClient is the number of channels to create per client.
+//
+// While the client and server will generally agree on this number, in reality
+// it's completely up to the server. We simply define a minimum of 2, and a
+// maximum of 4, and select the number of available processes as a tie-breaker.
+// Note that we don't want the number of channels to be too large, because each
+// will account for channelSize memory used, which can be large.
+var channelsPerClient = func() int {
+ n := runtime.NumCPU()
+ if n < 2 {
+ return 2
+ }
+ if n > 4 {
+ return 4
+ }
+ return n
+}()
+
+// channelSize is the channel size to create.
+//
+// We simply ensure that this is larger than the largest possible message size,
+// plus the flipcall packet header, plus the two bytes we write below.
+const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
+
+// channel is a fast IPC channel.
+//
+// The same object is used by both the server and client implementations. In
+// general, the client will use only the send and recv methods.
+type channel struct {
+ desc flipcall.PacketWindowDescriptor
+ data flipcall.Endpoint
+ fds fdchannel.Endpoint
+ buf buffer
+
+ // -- client only --
+ connected bool
+ active bool
+
+ // -- server only --
+ client *fd.FD
+ done chan struct{}
+}
+
+// reset resets the channel buffer.
+func (ch *channel) reset(sz uint32) {
+ ch.buf.data = ch.data.Data()[:sz]
+}
+
+// service services the channel.
+func (ch *channel) service(cs *connState) error {
+ rsz, err := ch.data.RecvFirst()
+ if err != nil {
+ return err
+ }
+ for rsz > 0 {
+ m, err := ch.recv(nil, rsz)
+ if err != nil {
+ return err
+ }
+ r := cs.handle(m)
+ msgRegistry.put(m)
+ rsz, err = ch.send(r)
+ if err != nil {
+ return err
+ }
+ }
+ return nil // Done.
+}
+
+// Shutdown shuts down the channel.
+//
+// This must be called before Close.
+func (ch *channel) Shutdown() {
+ ch.data.Shutdown()
+}
+
+// Close closes the channel.
+//
+// This must only be called once, and cannot return an error. Note that
+// synchronization for this method is provided at a high-level, depending on
+// whether it is the client or server. This cannot be called while there are
+// active callers in either service or sendRecv.
+//
+// Precondition: the channel should be shutdown.
+func (ch *channel) Close() error {
+ // Close all backing transports.
+ ch.fds.Destroy()
+ ch.data.Destroy()
+ if ch.client != nil {
+ ch.client.Close()
+ }
+ return nil
+}
+
+// send sends the given message.
+//
+// The return value is the size of the received response. Not that in the
+// server case, this is the size of the next request.
+func (ch *channel) send(m message) (uint32, error) {
+ if log.IsLogging(log.Debug) {
+ log.Debugf("send [channel @%p] %s", ch, m.String())
+ }
+
+ // Send any file payload.
+ sentFD := false
+ if filer, ok := m.(filer); ok {
+ if f := filer.FilePayload(); f != nil {
+ if err := ch.fds.SendFD(f.FD()); err != nil {
+ return 0, syscall.EIO // Map everything to EIO.
+ }
+ f.Close() // Per sendRecvLegacy.
+ sentFD = true // To mark below.
+ }
+ }
+
+ // Encode the message.
+ //
+ // Note that IPC itself encodes the length of messages, so we don't
+ // need to encode a standard 9P header. We write only the message type.
+ ch.reset(0)
+
+ ch.buf.WriteMsgType(m.Type())
+ if sentFD {
+ ch.buf.Write8(1) // Incoming FD.
+ } else {
+ ch.buf.Write8(0) // No incoming FD.
+ }
+ m.Encode(&ch.buf)
+ ssz := uint32(len(ch.buf.data)) // Updated below.
+
+ // Is there a payload?
+ if payloader, ok := m.(payloader); ok {
+ p := payloader.Payload()
+ copy(ch.data.Data()[ssz:], p)
+ ssz += uint32(len(p))
+ }
+
+ // Perform the one-shot communication.
+ n, err := ch.data.SendRecv(ssz)
+ if err != nil {
+ if n > 0 {
+ return n, nil
+ }
+ return 0, syscall.EIO // See above.
+ }
+
+ return n, nil
+}
+
+// recv decodes a message that exists on the channel.
+//
+// If the passed r is non-nil, then the type must match or an error will be
+// generated. If the passed r is nil, then a new message will be created and
+// returned.
+func (ch *channel) recv(r message, rsz uint32) (message, error) {
+ // Decode the response from the inline buffer.
+ ch.reset(rsz)
+ t := ch.buf.ReadMsgType()
+ hasFD := ch.buf.Read8() != 0
+ if t == MsgRlerror {
+ // Change the message type. We check for this special case
+ // after decoding below, and transform into an error.
+ r = &Rlerror{}
+ } else if r == nil {
+ nr, err := msgRegistry.get(0, t)
+ if err != nil {
+ return nil, err
+ }
+ r = nr // New message.
+ } else if t != r.Type() {
+ // Not an error and not the expected response; propagate.
+ return nil, &ErrBadResponse{Got: t, Want: r.Type()}
+ }
+
+ // Is there a payload? Copy from the latter portion.
+ if payloader, ok := r.(payloader); ok {
+ fs := payloader.FixedSize()
+ p := payloader.Payload()
+ payloadData := ch.buf.data[fs:]
+ if len(p) < len(payloadData) {
+ p = make([]byte, len(payloadData))
+ copy(p, payloadData)
+ payloader.SetPayload(p)
+ } else if n := copy(p, payloadData); n < len(p) {
+ payloader.SetPayload(p[:n])
+ }
+ ch.buf.data = ch.buf.data[:fs]
+ }
+
+ r.Decode(&ch.buf)
+ if ch.buf.isOverrun() {
+ // Nothing valid was available.
+ log.Debugf("recv [got %d bytes, needed more]", rsz)
+ return nil, ErrNoValidMessage
+ }
+
+ // Read any FD result.
+ if hasFD {
+ if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
+ f := fd.New(rfd)
+ if filer, ok := r.(filer); ok {
+ // Set the payload.
+ filer.SetFilePayload(f)
+ } else {
+ // Don't want the FD.
+ f.Close()
+ }
+ } else {
+ // The header bit was set but nothing came in.
+ log.Warningf("expected FD, got err: %v", err)
+ }
+ }
+
+ // Log a message.
+ if log.IsLogging(log.Debug) {
+ log.Debugf("recv [channel @%p] %s", ch, r.String())
+ }
+
+ // Convert errors appropriately; see above.
+ if rlerr, ok := r.(*Rlerror); ok {
+ return nil, syscall.Errno(rlerr.Error)
+ }
+
+ return r, nil
+}
+
+// sendRecv sends the given message over the channel.
+//
+// This is used by the client.
+func (ch *channel) sendRecv(c *Client, m, r message) error {
+ rsz, err := ch.send(m)
+ if err != nil {
+ return err
+ }
+ _, err = ch.recv(r, rsz)
+ return err
+}
diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go
index cdb3bc841..2f50ff3ea 100644
--- a/pkg/p9/transport_test.go
+++ b/pkg/p9/transport_test.go
@@ -124,7 +124,9 @@ func TestSendRecvWithFile(t *testing.T) {
t.Fatalf("unable to create file: %v", err)
}
- if err := send(client, Tag(1), &Rlopen{File: f}); err != nil {
+ rlopen := &Rlopen{}
+ rlopen.SetFilePayload(f)
+ if err := send(client, Tag(1), rlopen); err != nil {
t.Fatalf("send got err %v expected nil", err)
}
diff --git a/pkg/p9/version.go b/pkg/p9/version.go
index c2a2885ae..f1ffdd23a 100644
--- a/pkg/p9/version.go
+++ b/pkg/p9/version.go
@@ -26,7 +26,7 @@ const (
//
// Clients are expected to start requesting this version number and
// to continuously decrement it until a Tversion request succeeds.
- highestSupportedVersion uint32 = 7
+ highestSupportedVersion uint32 = 8
// lowestSupportedVersion is the lowest supported version X in a
// version string of the format 9P2000.L.Google.X.
@@ -148,3 +148,10 @@ func VersionSupportsMultiUser(v uint32) bool {
func versionSupportsTallocate(v uint32) bool {
return v >= 7
}
+
+// versionSupportsFlipcall returns true if version v supports IPC channels from
+// the flipcall package. Note that these must be negotiated, but this version
+// string indicates that such a facility exists.
+func versionSupportsFlipcall(v uint32) bool {
+ return v >= 8
+}
diff --git a/pkg/procid/BUILD b/pkg/procid/BUILD
index 697e7a2f4..078f084b2 100644
--- a/pkg/procid/BUILD
+++ b/pkg/procid/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/refs/BUILD b/pkg/refs/BUILD
index 9c08452fc..827385139 100644
--- a/pkg/refs/BUILD
+++ b/pkg/refs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "weak_ref_list",
diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD
index d1024e49d..af94e944d 100644
--- a/pkg/seccomp/BUILD
+++ b/pkg/seccomp/BUILD
@@ -1,5 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_embed_data", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD
index f38fb39f3..22abdc69f 100644
--- a/pkg/secio/BUILD
+++ b/pkg/secio/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/segment/test/BUILD b/pkg/segment/test/BUILD
index 694486296..12d7c77d2 100644
--- a/pkg/segment/test/BUILD
+++ b/pkg/segment/test/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(
default_visibility = ["//visibility:private"],
diff --git a/pkg/sentry/BUILD b/pkg/sentry/BUILD
index 53989301f..2d6379c86 100644
--- a/pkg/sentry/BUILD
+++ b/pkg/sentry/BUILD
@@ -8,5 +8,7 @@ package_group(
packages = [
"//pkg/sentry/...",
"//runsc/...",
+ # Code generated by go_marshal relies on go_marshal libraries.
+ "//tools/go_marshal/...",
],
)
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 7aace2d7b..c71cff9f3 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -42,6 +43,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "registers_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":registers_proto"],
+)
+
go_proto_library(
name = "registers_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto",
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index bf802d1b6..5522cecd0 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/device/BUILD b/pkg/sentry/device/BUILD
index 7e8918722..0c86197f7 100644
--- a/pkg/sentry/device/BUILD
+++ b/pkg/sentry/device/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "device",
diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD
index d7259b47b..3119a61b6 100644
--- a/pkg/sentry/fs/BUILD
+++ b/pkg/sentry/fs/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fs",
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index fbca06761..3cb73bd78 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -1126,7 +1126,7 @@ func (d *Dirent) unmount(ctx context.Context, replacement *Dirent) error {
// Remove removes the given file or symlink. The root dirent is used to
// resolve name, and must not be nil.
-func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
+func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string, dirPath bool) error {
// Check the root.
if root == nil {
panic("Dirent.Remove: root must not be nil")
@@ -1151,6 +1151,8 @@ func (d *Dirent) Remove(ctx context.Context, root *Dirent, name string) error {
// Remove cannot remove directories.
if IsDir(child.Inode.StableAttr) {
return syscall.EISDIR
+ } else if dirPath {
+ return syscall.ENOTDIR
}
// Remove cannot remove a mount point.
diff --git a/pkg/sentry/fs/dirent_refs_test.go b/pkg/sentry/fs/dirent_refs_test.go
index 884e3ff06..47bc72a88 100644
--- a/pkg/sentry/fs/dirent_refs_test.go
+++ b/pkg/sentry/fs/dirent_refs_test.go
@@ -343,7 +343,7 @@ func TestRemoveExtraRefs(t *testing.T) {
}
d := f.Dirent
- if err := test.root.Remove(contexttest.Context(t), test.root, name); err != nil {
+ if err := test.root.Remove(contexttest.Context(t), test.root, name, false /* dirPath */); err != nil {
t.Fatalf("root.Remove(root, %q) failed: %v", name, err)
}
diff --git a/pkg/sentry/fs/fdpipe/BUILD b/pkg/sentry/fs/fdpipe/BUILD
index bf00b9c09..b9bd9ed17 100644
--- a/pkg/sentry/fs/fdpipe/BUILD
+++ b/pkg/sentry/fs/fdpipe/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "fdpipe",
diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go
index bb8117f89..c0a6e884b 100644
--- a/pkg/sentry/fs/file.go
+++ b/pkg/sentry/fs/file.go
@@ -515,6 +515,11 @@ type lockedReader struct {
// File is the file to read from.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Read, not ReadAt.
+ Offset int64
}
// Read implements io.Reader.Read.
@@ -522,7 +527,8 @@ func (r *lockedReader) Read(buf []byte) (int, error) {
if r.Ctx.Interrupted() {
return 0, syserror.ErrInterrupted
}
- n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.File.offset)
+ n, err := r.File.FileOperations.Read(r.Ctx, r.File, usermem.BytesIOSequence(buf), r.Offset)
+ r.Offset += n
return int(n), err
}
@@ -544,11 +550,21 @@ type lockedWriter struct {
// File is the file to write to.
File *File
+
+ // Offset is the offset to start at.
+ //
+ // This applies only to Write, not WriteAt.
+ Offset int64
}
// Write implements io.Writer.Write.
func (w *lockedWriter) Write(buf []byte) (int, error) {
- return w.WriteAt(buf, w.File.offset)
+ if w.Ctx.Interrupted() {
+ return 0, syserror.ErrInterrupted
+ }
+ n, err := w.WriteAt(buf, w.Offset)
+ w.Offset += int64(n)
+ return int(n), err
}
// WriteAt implements io.Writer.WriteAt.
@@ -562,6 +578,9 @@ func (w *lockedWriter) WriteAt(buf []byte, offset int64) (int, error) {
// io.Copy, since our own Write interface does not have this same
// contract. Enforce that here.
for written < len(buf) {
+ if w.Ctx.Interrupted() {
+ return written, syserror.ErrInterrupted
+ }
var n int64
n, err = w.File.FileOperations.Write(w.Ctx, w.File, usermem.BytesIOSequence(buf[written:]), offset+int64(written))
if n > 0 {
diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go
index d86f5bf45..b88303f17 100644
--- a/pkg/sentry/fs/file_operations.go
+++ b/pkg/sentry/fs/file_operations.go
@@ -15,6 +15,8 @@
package fs
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/memmap"
@@ -105,8 +107,11 @@ type FileOperations interface {
// on the destination, following by a buffered copy with standard Read
// and Write operations.
//
+ // If dup is set, the data should be duplicated into the destination
+ // and retained.
+ //
// The same preconditions as Read apply.
- WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (int64, error)
+ WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (int64, error)
// Write writes src to file at offset and returns the number of bytes
// written which must be greater than or equal to 0. Like Read, file
@@ -126,7 +131,7 @@ type FileOperations interface {
// source. See WriteTo for details regarding how this is called.
//
// The same preconditions as Write apply; FileFlags.Write must be set.
- ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (int64, error)
+ ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (int64, error)
// Fsync writes buffered modifications of file and/or flushes in-flight
// operations to backing storage based on syncType. The range to sync is
diff --git a/pkg/sentry/fs/file_overlay.go b/pkg/sentry/fs/file_overlay.go
index 9820f0b13..225e40186 100644
--- a/pkg/sentry/fs/file_overlay.go
+++ b/pkg/sentry/fs/file_overlay.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/refs"
@@ -268,9 +269,9 @@ func (f *overlayFileOperations) Read(ctx context.Context, file *File, dst userme
}
// WriteTo implements FileOperations.WriteTo.
-func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) WriteTo(ctx context.Context, file *File, dst io.Writer, count int64, dup bool) (n int64, err error) {
err = f.onTop(ctx, file, func(file *File, ops FileOperations) error {
- n, err = ops.WriteTo(ctx, file, dst, opts)
+ n, err = ops.WriteTo(ctx, file, dst, count, dup)
return err // Will overwrite itself.
})
return
@@ -285,9 +286,9 @@ func (f *overlayFileOperations) Write(ctx context.Context, file *File, src userm
}
// ReadFrom implements FileOperations.ReadFrom.
-func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src *File, opts SpliceOpts) (n int64, err error) {
+func (f *overlayFileOperations) ReadFrom(ctx context.Context, file *File, src io.Reader, count int64) (n int64, err error) {
// See above; f.upper must be non-nil.
- return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, opts)
+ return f.upper.FileOperations.ReadFrom(ctx, f.upper, src, count)
}
// Fsync implements FileOperations.Fsync.
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 6499f87ac..b4ac83dc4 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "dirty_set_impl",
diff --git a/pkg/sentry/fs/fsutil/file.go b/pkg/sentry/fs/fsutil/file.go
index 626b9126a..fc5b3b1a1 100644
--- a/pkg/sentry/fs/fsutil/file.go
+++ b/pkg/sentry/fs/fsutil/file.go
@@ -15,6 +15,8 @@
package fsutil
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -228,12 +230,12 @@ func (FileNoIoctl) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArgu
type FileNoSplice struct{}
// WriteTo implements fs.FileOperations.WriteTo.
-func (FileNoSplice) WriteTo(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) WriteTo(context.Context, *fs.File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
// ReadFrom implements fs.FileOperations.ReadFrom.
-func (FileNoSplice) ReadFrom(context.Context, *fs.File, *fs.File, fs.SpliceOpts) (int64, error) {
+func (FileNoSplice) ReadFrom(context.Context, *fs.File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD
index 6b993928c..2b71ca0e1 100644
--- a/pkg/sentry/fs/gofer/BUILD
+++ b/pkg/sentry/fs/gofer/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "gofer",
diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD
index b1080fb1a..3e532332e 100644
--- a/pkg/sentry/fs/host/BUILD
+++ b/pkg/sentry/fs/host/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "host",
diff --git a/pkg/sentry/fs/inotify.go b/pkg/sentry/fs/inotify.go
index c7f4e2d13..ba3e0233d 100644
--- a/pkg/sentry/fs/inotify.go
+++ b/pkg/sentry/fs/inotify.go
@@ -15,6 +15,7 @@
package fs
import (
+ "io"
"sync"
"sync/atomic"
@@ -172,7 +173,7 @@ func (i *Inotify) Read(ctx context.Context, _ *File, dst usermem.IOSequence, _ i
}
// WriteTo implements FileOperations.WriteTo.
-func (*Inotify) WriteTo(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) WriteTo(context.Context, *File, io.Writer, int64, bool) (int64, error) {
return 0, syserror.ENOSYS
}
@@ -182,7 +183,7 @@ func (*Inotify) Fsync(context.Context, *File, int64, int64, SyncType) error {
}
// ReadFrom implements FileOperations.ReadFrom.
-func (*Inotify) ReadFrom(context.Context, *File, *File, SpliceOpts) (int64, error) {
+func (*Inotify) ReadFrom(context.Context, *File, io.Reader, int64) (int64, error) {
return 0, syserror.ENOSYS
}
diff --git a/pkg/sentry/fs/lock/BUILD b/pkg/sentry/fs/lock/BUILD
index 08d7c0c57..5a7a5b8cd 100644
--- a/pkg/sentry/fs/lock/BUILD
+++ b/pkg/sentry/fs/lock/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "lock_range",
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index c7599d1f6..1c93e8886 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "proc",
diff --git a/pkg/sentry/fs/proc/seqfile/BUILD b/pkg/sentry/fs/proc/seqfile/BUILD
index 20c3eefc8..76433c7d0 100644
--- a/pkg/sentry/fs/proc/seqfile/BUILD
+++ b/pkg/sentry/fs/proc/seqfile/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "seqfile",
diff --git a/pkg/sentry/fs/ramfs/BUILD b/pkg/sentry/fs/ramfs/BUILD
index 516efcc4c..d0f351e5a 100644
--- a/pkg/sentry/fs/ramfs/BUILD
+++ b/pkg/sentry/fs/ramfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "ramfs",
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index eed1c2854..b03b7f836 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -18,7 +18,6 @@ import (
"io"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/secio"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -33,146 +32,131 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
// Check whether or not the objects being sliced are stream-oriented
- // (i.e. pipes or sockets). If yes, we elide checks and offset locks.
- srcPipe := IsPipe(src.Dirent.Inode.StableAttr) || IsSocket(src.Dirent.Inode.StableAttr)
- dstPipe := IsPipe(dst.Dirent.Inode.StableAttr) || IsSocket(dst.Dirent.Inode.StableAttr)
+ // (i.e. pipes or sockets). For all stream-oriented files and files
+ // where a specific offiset is not request, we acquire the file mutex.
+ // This has two important side effects. First, it provides the standard
+ // protection against concurrent writes that would mutate the offset.
+ // Second, it prevents Splice deadlocks. Only internal anonymous files
+ // implement the ReadFrom and WriteTo methods directly, and since such
+ // anonymous files are referred to by a unique fs.File object, we know
+ // that the file mutex takes strict precedence over internal locks.
+ // Since we enforce lock ordering here, we can't deadlock by using
+ // using a file in two different splice operations simultaneously.
+ srcPipe := !IsRegular(src.Dirent.Inode.StableAttr)
+ dstPipe := !IsRegular(dst.Dirent.Inode.StableAttr)
+ dstAppend := !dstPipe && dst.Flags().Append
+ srcLock := srcPipe || !opts.SrcOffset
+ dstLock := dstPipe || !opts.DstOffset || dstAppend
- if !dstPipe && !opts.DstOffset && !srcPipe && !opts.SrcOffset {
+ switch {
+ case srcLock && dstLock:
switch {
case dst.UniqueID < src.UniqueID:
// Acquire dst first.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
if !src.mu.Lock(ctx) {
+ dst.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
case dst.UniqueID > src.UniqueID:
// Acquire src first.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
if !dst.mu.Lock(ctx) {
+ src.mu.Unlock()
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
case dst.UniqueID == src.UniqueID:
// Acquire only one lock; it's the same file. This is a
// bit of a edge case, but presumably it's possible.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
+ srcLock = false // Only need one unlock.
}
// Use both offsets (locked).
opts.DstStart = dst.offset
opts.SrcStart = src.offset
- } else if !dstPipe && !opts.DstOffset {
+ case dstLock:
// Acquire only dst.
if !dst.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer dst.mu.Unlock()
opts.DstStart = dst.offset // Safe: locked.
- } else if !srcPipe && !opts.SrcOffset {
+ case srcLock:
// Acquire only src.
if !src.mu.Lock(ctx) {
return 0, syserror.ErrInterrupted
}
- defer src.mu.Unlock()
opts.SrcStart = src.offset // Safe: locked.
}
- // Check append-only mode and the limit.
- if !dstPipe {
+ var err error
+ if dstAppend {
unlock := dst.Dirent.Inode.lockAppendMu(dst.Flags().Append)
defer unlock()
- if dst.Flags().Append {
- if opts.DstOffset {
- // We need to acquire the lock.
- if !dst.mu.Lock(ctx) {
- return 0, syserror.ErrInterrupted
- }
- defer dst.mu.Unlock()
- }
- // Figure out the appropriate offset to use.
- if err := dst.offsetForAppend(ctx, &opts.DstStart); err != nil {
- return 0, err
- }
- }
+ // Figure out the appropriate offset to use.
+ err = dst.offsetForAppend(ctx, &opts.DstStart)
+ }
+ if err == nil && !dstPipe {
// Enforce file limits.
limit, ok := dst.checkLimit(ctx, opts.DstStart)
switch {
case ok && limit == 0:
- return 0, syserror.ErrExceedsFileSizeLimit
+ err = syserror.ErrExceedsFileSizeLimit
case ok && limit < opts.Length:
opts.Length = limit // Cap the write.
}
}
+ if err != nil {
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+ return 0, err
+ }
- // Attempt to do a WriteTo; this is likely the most efficient.
- //
- // The underlying implementation may be able to donate buffers.
- newOpts := SpliceOpts{
- Length: opts.Length,
- SrcStart: opts.SrcStart,
- SrcOffset: !srcPipe,
- Dup: opts.Dup,
- DstStart: opts.DstStart,
- DstOffset: !dstPipe,
+ // Construct readers and writers for the splice. This is used to
+ // provide a safer locking path for the WriteTo/ReadFrom operations
+ // (since they will otherwise go through public interface methods which
+ // conflict with locking done above), and simplifies the fallback path.
+ w := &lockedWriter{
+ Ctx: ctx,
+ File: dst,
+ Offset: opts.DstStart,
}
- n, err := src.FileOperations.WriteTo(ctx, src, dst, newOpts)
- if n == 0 && err != nil {
- // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also
- // be more efficient than a copy if buffers are cached or readily
- // available. (It's unlikely that they can actually be donate
- n, err = dst.FileOperations.ReadFrom(ctx, dst, src, newOpts)
+ r := &lockedReader{
+ Ctx: ctx,
+ File: src,
+ Offset: opts.SrcStart,
}
- if n == 0 && err != nil {
- // If we've failed up to here, and at least one of the sources
- // is a pipe or socket, then we can't properly support dup.
- // Return an error indicating that this operation is not
- // supported.
- if (srcPipe || dstPipe) && newOpts.Dup {
- return 0, syserror.EINVAL
- }
- // We failed to splice the files. But that's fine; we just fall
- // back to a slow path in this case. This copies without doing
- // any mode changes, so should still be more efficient.
- var (
- r io.Reader
- w io.Writer
- )
- fw := &lockedWriter{
- Ctx: ctx,
- File: dst,
- }
- if newOpts.DstOffset {
- // Use the provided offset.
- w = secio.NewOffsetWriter(fw, newOpts.DstStart)
- } else {
- // Writes will proceed with no offset.
- w = fw
- }
- fr := &lockedReader{
- Ctx: ctx,
- File: src,
- }
- if newOpts.SrcOffset {
- // Limit to the given offset and length.
- r = io.NewSectionReader(fr, opts.SrcStart, opts.Length)
- } else {
- // Limit just to the given length.
- r = &io.LimitedReader{fr, opts.Length}
- }
+ // Attempt to do a WriteTo; this is likely the most efficient.
+ n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
+ if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup {
+ // Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be
+ // more efficient than a copy if buffers are cached or readily
+ // available. (It's unlikely that they can actually be donated).
+ n, err = dst.FileOperations.ReadFrom(ctx, dst, r, opts.Length)
+ }
- // Copy between the two.
- n, err = io.Copy(w, r)
+ // Support one last fallback option, but only if at least one of
+ // the source and destination are regular files. This is because
+ // if we block at some point, we could lose data. If the source is
+ // not a pipe then reading is not destructive; if the destination
+ // is a regular file, then it is guaranteed not to block writing.
+ if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) {
+ // Fallback to an in-kernel copy.
+ n, err = io.Copy(w, &io.LimitedReader{
+ R: r,
+ N: opts.Length,
+ })
}
// Update offsets, if required.
@@ -185,5 +169,13 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
}
}
+ // Drop locks.
+ if dstLock {
+ dst.mu.Unlock()
+ }
+ if srcLock {
+ src.mu.Unlock()
+ }
+
return n, err
}
diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD
index 8f7eb5757..11b680929 100644
--- a/pkg/sentry/fs/tmpfs/BUILD
+++ b/pkg/sentry/fs/tmpfs/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tmpfs",
diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD
index 5e9327aec..25811f668 100644
--- a/pkg/sentry/fs/tty/BUILD
+++ b/pkg/sentry/fs/tty/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tty",
@@ -23,6 +25,7 @@ go_library(
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/safemem",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go
index 1d128532b..2f639c823 100644
--- a/pkg/sentry/fs/tty/dir.go
+++ b/pkg/sentry/fs/tty/dir.go
@@ -129,6 +129,9 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode {
// Release implements fs.InodeOperations.Release.
func (d *dirInodeOperations) Release(ctx context.Context) {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
d.master.DecRef()
if len(d.slaves) != 0 {
panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d))
diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go
index 92ec1ca18..19b7557d5 100644
--- a/pkg/sentry/fs/tty/master.go
+++ b/pkg/sentry/fs/tty/master.go
@@ -172,6 +172,19 @@ func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io userme
return 0, mf.t.ld.windowSize(ctx, io, args)
case linux.TIOCSWINSZ:
return 0, mf.t.ld.setWindowSize(ctx, io, args)
+ case linux.TIOCSCTTY:
+ // Make the given terminal the controlling terminal of the
+ // calling process.
+ return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
@@ -185,8 +198,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
linux.TCSETS,
linux.TCSETSW,
linux.TCSETSF,
- linux.TIOCGPGRP,
- linux.TIOCSPGRP,
linux.TIOCGWINSZ,
linux.TIOCSWINSZ,
linux.TIOCSETD,
@@ -200,8 +211,6 @@ func maybeEmitUnimplementedEvent(ctx context.Context, cmd uint32) {
linux.TIOCEXCL,
linux.TIOCNXCL,
linux.TIOCGEXCL,
- linux.TIOCNOTTY,
- linux.TIOCSCTTY,
linux.TIOCGSID,
linux.TIOCGETD,
linux.TIOCVHANGUP,
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index e30266404..944c4ada1 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -152,9 +152,16 @@ func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem
case linux.TIOCSCTTY:
// Make the given terminal the controlling terminal of the
// calling process.
- // TODO(b/129283598): Implement once we have support for job
- // control.
- return 0, nil
+ return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCNOTTY:
+ // Release this process's controlling terminal.
+ return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */)
+ case linux.TIOCGPGRP:
+ // Get the foreground process group.
+ return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */)
+ case linux.TIOCSPGRP:
+ // Set the foreground process group.
+ return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */)
default:
maybeEmitUnimplementedEvent(ctx, cmd)
return 0, syserror.ENOTTY
diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go
index b7cecb2ed..ff8138820 100644
--- a/pkg/sentry/fs/tty/terminal.go
+++ b/pkg/sentry/fs/tty/terminal.go
@@ -17,7 +17,10 @@ package tty
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
)
// Terminal is a pseudoterminal.
@@ -26,23 +29,100 @@ import (
type Terminal struct {
refs.AtomicRefCount
- // n is the terminal index.
+ // n is the terminal index. It is immutable.
n uint32
- // d is the containing directory.
+ // d is the containing directory. It is immutable.
d *dirInodeOperations
- // ld is the line discipline of the terminal.
+ // ld is the line discipline of the terminal. It is immutable.
ld *lineDiscipline
+
+ // masterKTTY contains the controlling process of the master end of
+ // this terminal. This field is immutable.
+ masterKTTY *kernel.TTY
+
+ // slaveKTTY contains the controlling process of the slave end of this
+ // terminal. This field is immutable.
+ slaveKTTY *kernel.TTY
}
func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal {
termios := linux.DefaultSlaveTermios
t := Terminal{
- d: d,
- n: n,
- ld: newLineDiscipline(termios),
+ d: d,
+ n: n,
+ ld: newLineDiscipline(termios),
+ masterKTTY: &kernel.TTY{},
+ slaveKTTY: &kernel.TTY{},
}
t.EnableLeakCheck("tty.Terminal")
return &t
}
+
+// setControllingTTY makes tm the controlling terminal of the calling thread
+// group.
+func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().SetControllingTTY(tm.tty(isMaster), args[2].Int())
+}
+
+// releaseControllingTTY removes tm as the controlling terminal of the calling
+// thread group.
+func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("releaseControllingTTY must be called from a task context")
+ }
+
+ return task.ThreadGroup().ReleaseControllingTTY(tm.tty(isMaster))
+}
+
+// foregroundProcessGroup gets the process group ID of tm's foreground process.
+func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("foregroundProcessGroup must be called from a task context")
+ }
+
+ ret, err := task.ThreadGroup().ForegroundProcessGroup(tm.tty(isMaster))
+ if err != nil {
+ return 0, err
+ }
+
+ // Write it out to *arg.
+ _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+}
+
+// foregroundProcessGroup sets tm's foreground process.
+func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) {
+ task := kernel.TaskFromContext(ctx)
+ if task == nil {
+ panic("setForegroundProcessGroup must be called from a task context")
+ }
+
+ // Read in the process group ID.
+ var pgid int32
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ ret, err := task.ThreadGroup().SetForegroundProcessGroup(tm.tty(isMaster), kernel.ProcessGroupID(pgid))
+ return uintptr(ret), err
+}
+
+func (tm *Terminal) tty(isMaster bool) *kernel.TTY {
+ if isMaster {
+ return tm.masterKTTY
+ }
+ return tm.slaveKTTY
+}
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 9e8ebb907..b0c286b7a 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
go_template_instance(
diff --git a/pkg/sentry/fsimpl/ext/benchmark/BUILD b/pkg/sentry/fsimpl/ext/benchmark/BUILD
index 9fddb4c4c..bfc46dfa6 100644
--- a/pkg/sentry/fsimpl/ext/benchmark/BUILD
+++ b/pkg/sentry/fsimpl/ext/benchmark/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index b51f3e18d..0b471d121 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
}
if !cb.Handle(vfs.Dirent{
- Name: child.diskDirent.FileName(),
- Type: fs.ToDirentType(childType),
- Ino: uint64(child.diskDirent.Inode()),
- Off: fd.off,
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/ext/disklayout/BUILD b/pkg/sentry/fsimpl/ext/disklayout/BUILD
index 907d35b7e..2d50e30aa 100644
--- a/pkg/sentry/fsimpl/ext/disklayout/BUILD
+++ b/pkg/sentry/fsimpl/ext/disklayout/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "disklayout",
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 63cf7aeaf..1aa2bd6a4 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) {
// Ignore the inode number and offset of dirents because those are likely to
// change as the underlying image changes.
cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
- return p.String() == "Ino" || p.String() == "Off"
+ return p.String() == "Ino" || p.String() == "NextOff"
}, cmp.Ignore())
if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" {
t.Errorf("dirents mismatch (-want +got):\n%s", diff)
diff --git a/pkg/sentry/fsimpl/memfs/BUILD b/pkg/sentry/fsimpl/memfs/BUILD
index d2450e810..7e364c5fd 100644
--- a/pkg/sentry/fsimpl/memfs/BUILD
+++ b/pkg/sentry/fsimpl/memfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go
index c52dc781c..c620227c9 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/memfs/directory.go
@@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 0 {
if !cb.Handle(vfs.Dirent{
- Name: ".",
- Type: linux.DT_DIR,
- Ino: vfsd.Impl().(*dentry).inode.ino,
- Off: 0,
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: vfsd.Impl().(*dentry).inode.ino,
+ NextOff: 1,
}) {
return nil
}
@@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
if !cb.Handle(vfs.Dirent{
- Name: "..",
- Type: parentInode.direntType(),
- Ino: parentInode.ino,
- Off: 1,
+ Name: "..",
+ Type: parentInode.direntType(),
+ Ino: parentInode.ino,
+ NextOff: 2,
}) {
return nil
}
@@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Skip other directoryFD iterators.
if child.inode != nil {
if !cb.Handle(vfs.Dirent{
- Name: child.vfsd.Name(),
- Type: child.inode.direntType(),
- Ino: child.inode.ino,
- Off: fd.off,
+ Name: child.vfsd.Name(),
+ Type: child.inode.direntType(),
+ Ino: child.inode.ino,
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD
index 3d8a4deaf..ade6ac946 100644
--- a/pkg/sentry/fsimpl/proc/BUILD
+++ b/pkg/sentry/fsimpl/proc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/hostcpu/BUILD b/pkg/sentry/hostcpu/BUILD
index f989f2f8b..359468ccc 100644
--- a/pkg/sentry/hostcpu/BUILD
+++ b/pkg/sentry/hostcpu/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -6,6 +7,7 @@ go_library(
name = "hostcpu",
srcs = [
"getcpu_amd64.s",
+ "getcpu_arm64.s",
"hostcpu.go",
],
importpath = "gvisor.dev/gvisor/pkg/sentry/hostcpu",
diff --git a/pkg/sentry/hostcpu/getcpu_arm64.s b/pkg/sentry/hostcpu/getcpu_arm64.s
new file mode 100644
index 000000000..caf9abb89
--- /dev/null
+++ b/pkg/sentry/hostcpu/getcpu_arm64.s
@@ -0,0 +1,28 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// GetCPU makes the getcpu(unsigned *cpu, unsigned *node, NULL) syscall for
+// the lack of an optimazed way of getting the current CPU number on arm64.
+
+// func GetCPU() (cpu uint32)
+TEXT ·GetCPU(SB), NOSPLIT, $0-4
+ MOVW ZR, cpu+0(FP)
+ MOVD $cpu+0(FP), R0
+ MOVD $0x0, R1 // unused
+ MOVD $0x0, R2 // unused
+ MOVD $0xA8, R8 // SYS_GETCPU
+ SVC
+ RET
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e61d39c82..aba2414d4 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,9 +1,11 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "pending_signals_list",
@@ -83,6 +85,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "uncaught_signal_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":uncaught_signal_proto"],
+)
+
go_proto_library(
name = "uncaught_signal_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto",
@@ -144,6 +152,7 @@ go_library(
"threads.go",
"timekeeper.go",
"timekeeper_state.go",
+ "tty.go",
"uts_namespace.go",
"vdso.go",
"version.go",
diff --git a/pkg/sentry/kernel/epoll/BUILD b/pkg/sentry/kernel/epoll/BUILD
index f46c43128..65427b112 100644
--- a/pkg/sentry/kernel/epoll/BUILD
+++ b/pkg/sentry/kernel/epoll/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "epoll_list",
diff --git a/pkg/sentry/kernel/eventfd/BUILD b/pkg/sentry/kernel/eventfd/BUILD
index 1c5f979d4..983ca67ed 100644
--- a/pkg/sentry/kernel/eventfd/BUILD
+++ b/pkg/sentry/kernel/eventfd/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "eventfd",
diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD
index 6a31dc044..41f44999c 100644
--- a/pkg/sentry/kernel/futex/BUILD
+++ b/pkg/sentry/kernel/futex/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "atomicptr_bucket",
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index ebcfaa619..d7a7d1169 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -24,6 +25,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "memory_events_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":memory_events_proto"],
+)
+
go_proto_library(
name = "memory_events_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto",
diff --git a/pkg/sentry/kernel/pipe/BUILD b/pkg/sentry/kernel/pipe/BUILD
index 4d15cca85..2ce8952e2 100644
--- a/pkg/sentry/kernel/pipe/BUILD
+++ b/pkg/sentry/kernel/pipe/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "buffer_list",
diff --git a/pkg/sentry/kernel/pipe/buffer.go b/pkg/sentry/kernel/pipe/buffer.go
index 69ef2a720..95bee2d37 100644
--- a/pkg/sentry/kernel/pipe/buffer.go
+++ b/pkg/sentry/kernel/pipe/buffer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"sync"
"gvisor.dev/gvisor/pkg/sentry/safemem"
@@ -67,6 +68,17 @@ func (b *buffer) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
return n, err
}
+// WriteFromReader writes to the buffer from an io.Reader.
+func (b *buffer) WriteFromReader(r io.Reader, count int64) (int64, error) {
+ dst := b.data[b.write:]
+ if count < int64(len(dst)) {
+ dst = b.data[b.write:][:count]
+ }
+ n, err := r.Read(dst)
+ b.write += n
+ return int64(n), err
+}
+
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
src := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(b.data[b.read:b.write]))
@@ -75,6 +87,19 @@ func (b *buffer) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return n, err
}
+// ReadToWriter reads from the buffer into an io.Writer.
+func (b *buffer) ReadToWriter(w io.Writer, count int64, dup bool) (int64, error) {
+ src := b.data[b.read:b.write]
+ if count < int64(len(src)) {
+ src = b.data[b.read:][:count]
+ }
+ n, err := w.Write(src)
+ if !dup {
+ b.read += n
+ }
+ return int64(n), err
+}
+
// bufferPool is a pool for buffers.
var bufferPool = sync.Pool{
New: func() interface{} {
diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go
index 247e2928e..93b50669f 100644
--- a/pkg/sentry/kernel/pipe/pipe.go
+++ b/pkg/sentry/kernel/pipe/pipe.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
- "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -173,13 +172,24 @@ func (p *Pipe) Open(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) *fs.F
}
}
+type readOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit limits subsequence reads.
+ limit func(int64)
+
+ // read performs the actual read operation.
+ read func(*buffer) (int64, error)
+}
+
// read reads data from the pipe into dst and returns the number of bytes
// read, or returns ErrWouldBlock if the pipe is empty.
//
// Precondition: this pipe must have readers.
-func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error) {
+func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) {
// Don't block for a zero-length read even if the pipe is empty.
- if dst.NumBytes() == 0 {
+ if ops.left() == 0 {
return 0, nil
}
@@ -196,12 +206,12 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Limit how much we consume.
- if dst.NumBytes() > p.size {
- dst = dst.TakeFirst64(p.size)
+ if ops.left() > p.size {
+ ops.limit(p.size)
}
done := int64(0)
- for dst.NumBytes() > 0 {
+ for ops.left() > 0 {
// Pop the first buffer.
first := p.data.Front()
if first == nil {
@@ -209,10 +219,9 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := dst.CopyOutFrom(ctx, first)
+ n, err := ops.read(first)
done += int64(n)
p.size -= n
- dst = dst.DropFirst64(n)
// Empty buffer?
if first.Empty() {
@@ -230,12 +239,57 @@ func (p *Pipe) read(ctx context.Context, dst usermem.IOSequence) (int64, error)
return done, nil
}
+// dup duplicates all data from this pipe into the given writer.
+//
+// There is no blocking behavior implemented here. The writer may propagate
+// some blocking error. All the writes must be complete writes.
+func (p *Pipe) dup(ctx context.Context, ops readOps) (int64, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ // Is the pipe empty?
+ if p.size == 0 {
+ if !p.HasWriters() {
+ // See above.
+ return 0, nil
+ }
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Limit how much we consume.
+ if ops.left() > p.size {
+ ops.limit(p.size)
+ }
+
+ done := int64(0)
+ for buf := p.data.Front(); buf != nil; buf = buf.Next() {
+ n, err := ops.read(buf)
+ done += n
+ if err != nil {
+ return done, err
+ }
+ }
+
+ return done, nil
+}
+
+type writeOps struct {
+ // left returns the bytes remaining.
+ left func() int64
+
+ // limit should limit subsequent writes.
+ limit func(int64)
+
+ // write should write to the provided buffer.
+ write func(*buffer) (int64, error)
+}
+
// write writes data from sv into the pipe and returns the number of bytes
// written. If no bytes are written because the pipe is full (or has less than
// atomicIOBytes free capacity), write returns ErrWouldBlock.
//
// Precondition: this pipe must have writers.
-func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error) {
+func (p *Pipe) write(ctx context.Context, ops writeOps) (int64, error) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -246,17 +300,16 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
// POSIX requires that a write smaller than atomicIOBytes (PIPE_BUF) be
// atomic, but requires no atomicity for writes larger than this.
- wanted := src.NumBytes()
+ wanted := ops.left()
if avail := p.max - p.size; wanted > avail {
if wanted <= p.atomicIOBytes {
return 0, syserror.ErrWouldBlock
}
- // Limit to the available capacity.
- src = src.TakeFirst64(avail)
+ ops.limit(avail)
}
done := int64(0)
- for src.NumBytes() > 0 {
+ for ops.left() > 0 {
// Need a new buffer?
last := p.data.Back()
if last == nil || last.Full() {
@@ -266,10 +319,9 @@ func (p *Pipe) write(ctx context.Context, src usermem.IOSequence) (int64, error)
}
// Copy user data.
- n, err := src.CopyInTo(ctx, last)
+ n, err := ops.write(last)
done += int64(n)
p.size += n
- src = src.DropFirst64(n)
// Handle errors.
if err != nil {
diff --git a/pkg/sentry/kernel/pipe/reader_writer.go b/pkg/sentry/kernel/pipe/reader_writer.go
index f69dbf27b..7c307f013 100644
--- a/pkg/sentry/kernel/pipe/reader_writer.go
+++ b/pkg/sentry/kernel/pipe/reader_writer.go
@@ -15,6 +15,7 @@
package pipe
import (
+ "io"
"math"
"syscall"
@@ -55,7 +56,45 @@ func (rw *ReaderWriter) Release() {
// Read implements fs.FileOperations.Read.
func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.read(ctx, dst)
+ n, err := rw.Pipe.read(ctx, readOps{
+ left: func() int64 {
+ return dst.NumBytes()
+ },
+ limit: func(l int64) {
+ dst = dst.TakeFirst64(l)
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, buf)
+ dst = dst.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventOut)
+ }
+ return n, err
+}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) WriteTo(ctx context.Context, _ *fs.File, w io.Writer, count int64, dup bool) (int64, error) {
+ ops := readOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ read: func(buf *buffer) (int64, error) {
+ n, err := buf.ReadToWriter(w, count, dup)
+ count -= n
+ return n, err
+ },
+ }
+ if dup {
+ // There is no notification for dup operations.
+ return rw.Pipe.dup(ctx, ops)
+ }
+ n, err := rw.Pipe.read(ctx, ops)
if n > 0 {
rw.Pipe.Notify(waiter.EventOut)
}
@@ -64,7 +103,40 @@ func (rw *ReaderWriter) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequ
// Write implements fs.FileOperations.Write.
func (rw *ReaderWriter) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- n, err := rw.Pipe.write(ctx, src)
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return src.NumBytes()
+ },
+ limit: func(l int64) {
+ src = src.TakeFirst64(l)
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := src.CopyInTo(ctx, buf)
+ src = src.DropFirst64(n)
+ return n, err
+ },
+ })
+ if n > 0 {
+ rw.Pipe.Notify(waiter.EventIn)
+ }
+ return n, err
+}
+
+// ReadFrom implements fs.FileOperations.WriteTo.
+func (rw *ReaderWriter) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ n, err := rw.Pipe.write(ctx, writeOps{
+ left: func() int64 {
+ return count
+ },
+ limit: func(l int64) {
+ count = l
+ },
+ write: func(buf *buffer) (int64, error) {
+ n, err := buf.WriteFromReader(r, count)
+ count -= n
+ return n, err
+ },
+ })
if n > 0 {
rw.Pipe.Notify(waiter.EventIn)
}
diff --git a/pkg/sentry/kernel/sched/BUILD b/pkg/sentry/kernel/sched/BUILD
index 1725b8562..98ea7a0d8 100644
--- a/pkg/sentry/kernel/sched/BUILD
+++ b/pkg/sentry/kernel/sched/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/kernel/semaphore/BUILD b/pkg/sentry/kernel/semaphore/BUILD
index 36edf10f3..80e5e5da3 100644
--- a/pkg/sentry/kernel/semaphore/BUILD
+++ b/pkg/sentry/kernel/semaphore/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "waiter_list",
diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go
index 81fcd8258..047b5214d 100644
--- a/pkg/sentry/kernel/sessions.go
+++ b/pkg/sentry/kernel/sessions.go
@@ -47,6 +47,11 @@ type Session struct {
// The id is immutable.
id SessionID
+ // foreground is the foreground process group.
+ //
+ // This is protected by TaskSet.mu.
+ foreground *ProcessGroup
+
// ProcessGroups is a list of process groups in this Session. This is
// protected by TaskSet.mu.
processGroups processGroupList
@@ -260,12 +265,14 @@ func (pg *ProcessGroup) SendSignal(info *arch.SignalInfo) error {
func (tg *ThreadGroup) CreateSession() error {
tg.pidns.owner.mu.Lock()
defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
return tg.createSession()
}
// createSession creates a new session for a threadgroup.
//
-// Precondition: callers must hold TaskSet.mu for writing.
+// Precondition: callers must hold TaskSet.mu and the signal mutex for writing.
func (tg *ThreadGroup) createSession() error {
// Get the ID for this thread in the current namespace.
id := tg.pidns.tgids[tg]
@@ -321,8 +328,14 @@ func (tg *ThreadGroup) createSession() error {
childTG.processGroup.incRefWithParent(pg)
childTG.processGroup.decRefWithParent(oldParentPG)
})
- tg.processGroup.decRefWithParent(oldParentPG)
+ // If tg.processGroup is an orphan, decRefWithParent will lock
+ // the signal mutex of each thread group in tg.processGroup.
+ // However, tg's signal mutex may already be locked at this
+ // point. We change tg's process group before calling
+ // decRefWithParent to avoid locking tg's signal mutex twice.
+ oldPG := tg.processGroup
tg.processGroup = pg
+ oldPG.decRefWithParent(oldParentPG)
} else {
// The current process group may be nil only in the case of an
// unparented thread group (i.e. the init process). This would
@@ -346,6 +359,9 @@ func (tg *ThreadGroup) createSession() error {
ns.processGroups[ProcessGroupID(local)] = pg
}
+ // Disconnect from the controlling terminal.
+ tg.tty = nil
+
return nil
}
diff --git a/pkg/sentry/kernel/signalfd/BUILD b/pkg/sentry/kernel/signalfd/BUILD
new file mode 100644
index 000000000..50b69d154
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/BUILD
@@ -0,0 +1,22 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_library(
+ name = "signalfd",
+ srcs = ["signalfd.go"],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/anon",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/usermem",
+ "//pkg/syserror",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
new file mode 100644
index 000000000..06fd5ec88
--- /dev/null
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -0,0 +1,137 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package signalfd provides an implementation of signal file descriptors.
+package signalfd
+
+import (
+ "sync"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/anon"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SignalOperations represent a file with signalfd semantics.
+//
+// +stateify savable
+type SignalOperations struct {
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoWrite `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ // target is the original task target.
+ //
+ // The semantics here are a bit broken. Linux will always use current
+ // for all reads, regardless of where the signalfd originated. We can't
+ // do exactly that because we need to plumb the context through
+ // EventRegister in order to support proper blocking behavior. This
+ // will undoubtedly become very complicated quickly.
+ target *kernel.Task
+
+ // mu protects below.
+ mu sync.Mutex `state:"nosave"`
+
+ // mask is the signal mask. Protected by mu.
+ mask linux.SignalSet
+}
+
+// New creates a new signalfd object with the supplied mask.
+func New(ctx context.Context, mask linux.SignalSet) (*fs.File, error) {
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ // No task context? Not valid.
+ return nil, syserror.EINVAL
+ }
+ // name matches fs/signalfd.c:signalfd4.
+ dirent := fs.NewDirent(ctx, anon.NewInode(ctx), "anon_inode:[signalfd]")
+ return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &SignalOperations{
+ target: t,
+ mask: mask,
+ }), nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *SignalOperations) Release() {}
+
+// Mask returns the signal mask.
+func (s *SignalOperations) Mask() linux.SignalSet {
+ s.mu.Lock()
+ mask := s.mask
+ s.mu.Unlock()
+ return mask
+}
+
+// SetMask sets the signal mask.
+func (s *SignalOperations) SetMask(mask linux.SignalSet) {
+ s.mu.Lock()
+ s.mask = mask
+ s.mu.Unlock()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ // Attempt to dequeue relevant signals.
+ info, err := s.target.Sigtimedwait(s.Mask(), 0)
+ if err != nil {
+ // There must be no signal available.
+ return 0, syserror.ErrWouldBlock
+ }
+
+ // Copy out the signal info using the specified format.
+ var buf [128]byte
+ binary.Marshal(buf[:0], usermem.ByteOrder, &linux.SignalfdSiginfo{
+ Signo: uint32(info.Signo),
+ Errno: info.Errno,
+ Code: info.Code,
+ PID: uint32(info.Pid()),
+ UID: uint32(info.Uid()),
+ Status: info.Status(),
+ Overrun: uint32(info.Overrun()),
+ Addr: info.Addr(),
+ })
+ n, err := dst.CopyOut(ctx, buf[:])
+ return int64(n), err
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SignalOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return mask & waiter.EventIn
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SignalOperations) EventRegister(entry *waiter.Entry, _ waiter.EventMask) {
+ // Register for the signal set; ignore the passed events.
+ s.target.SignalRegister(entry, waiter.EventMask(s.Mask()))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SignalOperations) EventUnregister(entry *waiter.Entry) {
+ // Unregister the original entry.
+ s.target.SignalUnregister(entry)
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index e91f82bb3..c82ef5486 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -35,6 +35,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/uniqueid"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
"gvisor.dev/gvisor/third_party/gvsync"
)
@@ -133,6 +134,13 @@ type Task struct {
// signalStack is exclusive to the task goroutine.
signalStack arch.SignalStack
+ // signalQueue is a set of registered waiters for signal-related events.
+ //
+ // signalQueue is protected by the signalMutex. Note that the task does
+ // not implement all queue methods, specifically the readiness checks.
+ // The task only broadcast a notification on signal delivery.
+ signalQueue waiter.Queue `state:"zerovalue"`
+
// If groupStopPending is true, the task should participate in a group
// stop in the interrupt path.
//
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 266959a07..39cd1340d 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -28,6 +28,7 @@ import (
ucspb "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// SignalAction is an internal signal action.
@@ -497,6 +498,9 @@ func (tg *ThreadGroup) applySignalSideEffectsLocked(sig linux.Signal) {
//
// Preconditions: The signal mutex must be locked.
func (t *Task) canReceiveSignalLocked(sig linux.Signal) bool {
+ // Notify that the signal is queued.
+ t.signalQueue.Notify(waiter.EventMask(linux.MakeSignalSet(sig)))
+
// - Do not choose tasks that are blocking the signal.
if linux.SignalSetOf(sig)&t.signalMask != 0 {
return false
@@ -1108,3 +1112,17 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
t.tg.signalHandlers.mu.Unlock()
return t.deliverSignal(info, act)
}
+
+// SignalRegister registers a waiter for pending signals.
+func (t *Task) SignalRegister(e *waiter.Entry, mask waiter.EventMask) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventRegister(e, mask)
+ t.tg.signalHandlers.mu.Unlock()
+}
+
+// SignalUnregister unregisters a waiter for pending signals.
+func (t *Task) SignalUnregister(e *waiter.Entry) {
+ t.tg.signalHandlers.mu.Lock()
+ t.signalQueue.EventUnregister(e)
+ t.tg.signalHandlers.mu.Unlock()
+}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index d60cd62c7..ae6fc4025 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -172,9 +172,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
if parentPG := tg.parentPG(); parentPG == nil {
tg.createSession()
} else {
- // Inherit the process group.
+ // Inherit the process group and terminal.
parentPG.incRefWithParent(parentPG)
tg.processGroup = parentPG
+ tg.tty = t.parent.tg.tty
}
}
tg.tasks.PushBack(t)
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 2a97e3e8e..0eef24bfb 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -19,10 +19,13 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sentry/usage"
+ "gvisor.dev/gvisor/pkg/syserror"
)
// A ThreadGroup is a logical grouping of tasks that has widespread
@@ -245,6 +248,12 @@ type ThreadGroup struct {
//
// mounts is immutable.
mounts *fs.MountNamespace
+
+ // tty is the thread group's controlling terminal. If nil, there is no
+ // controlling terminal.
+ //
+ // tty is protected by the signal mutex.
+ tty *TTY
}
// newThreadGroup returns a new, empty thread group in PID namespace ns. The
@@ -324,6 +333,176 @@ func (tg *ThreadGroup) forEachChildThreadGroupLocked(fn func(*ThreadGroup)) {
}
}
+// SetControllingTTY sets tty as the controlling terminal of tg.
+func (tg *ThreadGroup) SetControllingTTY(tty *TTY, arg int32) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ // We might be asked to set the controlling terminal of multiple
+ // processes, so we lock both the TaskSet and SignalHandlers.
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "The calling process must be a session leader and not have a
+ // controlling terminal already." - tty_ioctl(4)
+ if tg.processGroup.session.leader != tg || tg.tty != nil {
+ return syserror.EINVAL
+ }
+
+ // "If this terminal is already the controlling terminal of a different
+ // session group, then the ioctl fails with EPERM, unless the caller
+ // has the CAP_SYS_ADMIN capability and arg equals 1, in which case the
+ // terminal is stolen, and all processes that had it as controlling
+ // terminal lose it." - tty_ioctl(4)
+ if tty.tg != nil && tg.processGroup.session != tty.tg.processGroup.session {
+ if !auth.CredentialsFromContext(tg.leader).HasCapability(linux.CAP_SYS_ADMIN) || arg != 1 {
+ return syserror.EPERM
+ }
+ // Steal the TTY away. Unlike TIOCNOTTY, don't send signals.
+ for othertg := range tg.pidns.owner.Root.tgids {
+ // This won't deadlock by locking tg.signalHandlers
+ // because at this point:
+ // - We only lock signalHandlers if it's in the same
+ // session as the tty's controlling thread group.
+ // - We know that the calling thread group is not in
+ // the same session as the tty's controlling thread
+ // group.
+ if othertg.processGroup.session == tty.tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+ }
+
+ // Set the controlling terminal and foreground process group.
+ tg.tty = tty
+ tg.processGroup.session.foreground = tg.processGroup
+ // Set this as the controlling process of the terminal.
+ tty.tg = tg
+
+ return nil
+}
+
+// ReleaseControllingTTY gives up tty as the controlling tty of tg.
+func (tg *ThreadGroup) ReleaseControllingTTY(tty *TTY) error {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ // We might be asked to set the controlling terminal of multiple
+ // processes, so we lock both the TaskSet and SignalHandlers.
+ tg.pidns.owner.mu.RLock()
+ defer tg.pidns.owner.mu.RUnlock()
+
+ // Just below, we may re-lock signalHandlers in order to send signals.
+ // Thus we can't defer Unlock here.
+ tg.signalHandlers.mu.Lock()
+
+ if tg.tty == nil || tg.tty != tty {
+ tg.signalHandlers.mu.Unlock()
+ return syserror.ENOTTY
+ }
+
+ // "If the process was session leader, then send SIGHUP and SIGCONT to
+ // the foreground process group and all processes in the current
+ // session lose their controlling terminal." - tty_ioctl(4)
+ // Remove tty as the controlling tty for each process in the session,
+ // then send them SIGHUP and SIGCONT.
+
+ // If we're not the session leader, we don't have to do much.
+ if tty.tg != tg {
+ tg.tty = nil
+ tg.signalHandlers.mu.Unlock()
+ return nil
+ }
+
+ tg.signalHandlers.mu.Unlock()
+
+ // We're the session leader. SIGHUP and SIGCONT the foreground process
+ // group and remove all controlling terminals in the session.
+ var lastErr error
+ for othertg := range tg.pidns.owner.Root.tgids {
+ if othertg.processGroup.session == tg.processGroup.session {
+ othertg.signalHandlers.mu.Lock()
+ othertg.tty = nil
+ if othertg.processGroup == tg.processGroup.session.foreground {
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGHUP)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ if err := othertg.leader.sendSignalLocked(&arch.SignalInfo{Signo: int32(linux.SIGCONT)}, true /* group */); err != nil {
+ lastErr = err
+ }
+ }
+ othertg.signalHandlers.mu.Unlock()
+ }
+ }
+
+ return lastErr
+}
+
+// ForegroundProcessGroup returns the process group ID of the foreground
+// process group.
+func (tg *ThreadGroup) ForegroundProcessGroup(tty *TTY) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // "When fd does not refer to the controlling terminal of the calling
+ // process, -1 is returned" - tcgetpgrp(3)
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ return int32(tg.processGroup.session.foreground.id), nil
+}
+
+// SetForegroundProcessGroup sets the foreground process group of tty to pgid.
+func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) (int32, error) {
+ tty.mu.Lock()
+ defer tty.mu.Unlock()
+
+ tg.pidns.owner.mu.Lock()
+ defer tg.pidns.owner.mu.Unlock()
+ tg.signalHandlers.mu.Lock()
+ defer tg.signalHandlers.mu.Unlock()
+
+ // TODO(b/129283598): "If tcsetpgrp() is called by a member of a
+ // background process group in its session, and the calling process is
+ // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all
+ // members of this background process group."
+
+ // tty must be the controlling terminal.
+ if tg.tty != tty {
+ return -1, syserror.ENOTTY
+ }
+
+ // pgid must be positive.
+ if pgid < 0 {
+ return -1, syserror.EINVAL
+ }
+
+ // pg must not be empty. Empty process groups are removed from their
+ // pid namespaces.
+ pg, ok := tg.pidns.processGroups[pgid]
+ if !ok {
+ return -1, syserror.ESRCH
+ }
+
+ // pg must be part of this process's session.
+ if tg.processGroup.session != pg.session {
+ return -1, syserror.EPERM
+ }
+
+ tg.processGroup.session.foreground.id = pgid
+ return 0, nil
+}
+
// itimerRealListener implements ktime.Listener for ITIMER_REAL expirations.
//
// +stateify savable
diff --git a/pkg/sentry/kernel/tty.go b/pkg/sentry/kernel/tty.go
new file mode 100644
index 000000000..34f84487a
--- /dev/null
+++ b/pkg/sentry/kernel/tty.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import "sync"
+
+// TTY defines the relationship between a thread group and its controlling
+// terminal.
+//
+// +stateify savable
+type TTY struct {
+ mu sync.Mutex `state:"nosave"`
+
+ // tg is protected by mu.
+ tg *ThreadGroup
+}
diff --git a/pkg/sentry/limits/BUILD b/pkg/sentry/limits/BUILD
index 40025d62d..59649c770 100644
--- a/pkg/sentry/limits/BUILD
+++ b/pkg/sentry/limits/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "limits",
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index 29c14ec56..9687e7e76 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "mappable_range",
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index 072745a08..b35c8c673 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "file_refcount_set",
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index 858f895f2..3fd904c67 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "evictable_range",
diff --git a/pkg/sentry/platform/interrupt/BUILD b/pkg/sentry/platform/interrupt/BUILD
index eeb634644..b6d008dbe 100644
--- a/pkg/sentry/platform/interrupt/BUILD
+++ b/pkg/sentry/platform/interrupt/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index fe979dccf..31fa48ec5 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/ptrace/ptrace_unsafe.go b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
index 47957bb3b..72c7ec564 100644
--- a/pkg/sentry/platform/ptrace/ptrace_unsafe.go
+++ b/pkg/sentry/platform/ptrace/ptrace_unsafe.go
@@ -154,3 +154,19 @@ func (t *thread) clone() (*thread, error) {
cpu: ^uint32(0),
}, nil
}
+
+// getEventMessage retrieves a message about the ptrace event that just happened.
+func (t *thread) getEventMessage() (uintptr, error) {
+ var msg uintptr
+ _, _, errno := syscall.RawSyscall6(
+ syscall.SYS_PTRACE,
+ syscall.PTRACE_GETEVENTMSG,
+ uintptr(t.tid),
+ 0,
+ uintptr(unsafe.Pointer(&msg)),
+ 0, 0)
+ if errno != 0 {
+ return msg, errno
+ }
+ return msg, nil
+}
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 6bf7cd097..4f8f9c5d9 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -355,7 +355,8 @@ func (t *thread) wait(outcome waitOutcome) syscall.Signal {
}
if stopSig == syscall.SIGTRAP {
if status.TrapCause() == syscall.PTRACE_EVENT_EXIT {
- t.dumpAndPanic("wait failed: the process exited")
+ msg, err := t.getEventMessage()
+ t.dumpAndPanic(fmt.Sprintf("wait failed: the process %d:%d exited: %x (err %v)", t.tgid, t.tid, msg, err))
}
// Re-encode the trap cause the way it's expected.
return stopSig | syscall.Signal(status.TrapCause()<<8)
@@ -426,6 +427,9 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
break
} else {
// Some other signal caused a thread stop; ignore.
+ if sig != syscall.SIGSTOP && sig != syscall.SIGCHLD {
+ log.Warningf("The thread %d:%d has been interrupted by %d", t.tgid, t.tid, sig)
+ }
continue
}
}
diff --git a/pkg/sentry/platform/ring0/pagetables/BUILD b/pkg/sentry/platform/ring0/pagetables/BUILD
index 3b95af617..ea090b686 100644
--- a/pkg/sentry/platform/ring0/pagetables/BUILD
+++ b/pkg/sentry/platform/ring0/pagetables/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/platform/safecopy/BUILD b/pkg/sentry/platform/safecopy/BUILD
index 924d8a6d6..6769cd0a5 100644
--- a/pkg/sentry/platform/safecopy/BUILD
+++ b/pkg/sentry/platform/safecopy/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD
index fd6dc8e6e..884020f7b 100644
--- a/pkg/sentry/safemem/BUILD
+++ b/pkg/sentry/safemem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 0e37ce61b..3e66f9cbb 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -26,6 +26,7 @@ package epsocket
import (
"bytes"
+ "io"
"math"
"reflect"
"sync"
@@ -208,6 +209,10 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
@@ -227,7 +232,6 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileNoFsync `state:"nosave"`
fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
*waiter.Queue
@@ -412,17 +416,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
return int64(n), nil
}
-// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand
-// based on the requested size.
+// WriteTo implements fs.FileOperations.WriteTo.
+func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+ s.readMu.Lock()
+
+ // Copy as much data as possible.
+ done := int64(0)
+ for count > 0 {
+ // This may return a blocking error.
+ if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
+ return done, err.ToError()
+ }
+
+ // Write to the underlying file.
+ n, err := dst.Write(s.readView)
+ done += int64(n)
+ count -= int64(n)
+ if dup {
+ // That's all we support for dup. This is generally
+ // supported by any Linux system calls, but the
+ // expectation is that now a caller will call read to
+ // actually remove these bytes from the socket.
+ break
+ }
+
+ // Drop that part of the view.
+ s.readView.TrimFront(n)
+ if err != nil {
+ s.readMu.Unlock()
+ return done, err
+ }
+ }
+
+ s.readMu.Unlock()
+ return done, nil
+}
+
+// ioSequencePayload implements tcpip.Payload.
+//
+// t copies user memory bytes on demand based on the requested size.
type ioSequencePayload struct {
ctx context.Context
src usermem.IOSequence
}
-// Get implements tcpip.Payload.
-func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
- if size > i.Size() {
- size = i.Size()
+// FullPayload implements tcpip.Payloader.FullPayload
+func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
+ return i.Payload(int(i.src.NumBytes()))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if max := int(i.src.NumBytes()); size > max {
+ size = max
}
v := buffer.NewView(size)
if _, err := i.src.CopyIn(i.ctx, v); err != nil {
@@ -431,11 +478,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
return v, nil
}
-// Size implements tcpip.Payload.
-func (i *ioSequencePayload) Size() int {
- return int(i.src.NumBytes())
-}
-
// DropFirst drops the first n bytes from underlying src.
func (i *ioSequencePayload) DropFirst(n int) {
i.src = i.src.DropFirst(int(n))
@@ -469,6 +511,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
return int64(n), nil
}
+// readerPayload implements tcpip.Payloader.
+//
+// It allocates a view and reads from a reader on-demand, based on available
+// capacity in the endpoint.
+type readerPayload struct {
+ ctx context.Context
+ r io.Reader
+ count int64
+ err error
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
+ return r.Payload(int(r.count))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if size > int(r.count) {
+ size = int(r.count)
+ }
+ v := buffer.NewView(size)
+ n, err := r.r.Read(v)
+ if n > 0 {
+ // We ignore the error here. It may re-occur on subsequent
+ // reads, but for now we can enqueue some amount of data.
+ r.count -= int64(n)
+ return v[:n], nil
+ }
+ if err == syserror.ErrWouldBlock {
+ return nil, tcpip.ErrWouldBlock
+ } else if err != nil {
+ r.err = err // Save for propation.
+ return nil, tcpip.ErrBadAddress
+ }
+
+ // There is no data and no error. Return an error, which will propagate
+ // r.err, which will be nil. This is the desired result: (0, nil).
+ return nil, tcpip.ErrBadAddress
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ f := &readerPayload{ctx: ctx, r: r, count: count}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
+ Atomic: true, // See above.
+ })
+ }
+ if err == tcpip.ErrWouldBlock {
+ return n, syserror.ErrWouldBlock
+ } else if err != nil {
+ return int64(n), f.err // Propagate error.
+ }
+
+ return int64(n), nil
+}
+
// Readiness returns a mask of ready events for socket s.
func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
r := s.Endpoint.Readiness(mask)
@@ -777,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -793,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1165,7 +1279,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1173,7 +1287,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -2060,7 +2174,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= int64(v.Size()) || dontWait) {
+ if err == nil && (n >= v.src.NumBytes() || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2085,7 +2199,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
@@ -2207,9 +2321,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCOUTQ:
- var v tcpip.SendQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 9e2e12799..445080aa4 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "port",
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 5061dcbde..3a6baa308 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -49,6 +50,14 @@ proto_library(
],
)
+cc_proto_library(
+ name = "syscall_rpc_cc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [":syscall_rpc_proto"],
+)
+
go_proto_library(
name = "syscall_rpc_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2b0ad6395..1867b3a5c 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,6 +175,10 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
@@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrQueueSizeNotSupported
}
return v, nil
- default:
- return -1, tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
- case *tcpip.SendQueueSizeOption:
+ case tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v := e.connected.SendQueuedSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
- }
- *o = qs
- return nil
-
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- return nil
+ return int(v), nil
- case *tcpip.SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ v := e.connected.SendMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
- return nil
+ return int(v), nil
- case *tcpip.ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
e.Lock()
if e.receiver == nil {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ v := e.receiver.RecvMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
}
- *o = qs
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 445d25010..7d7b42eba 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -44,6 +45,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "strace_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":strace_proto"],
+)
+
go_proto_library(
name = "strace_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto",
diff --git a/pkg/sentry/strace/linux64.go b/pkg/sentry/strace/linux64.go
index 3650fd6e1..5d57b75af 100644
--- a/pkg/sentry/strace/linux64.go
+++ b/pkg/sentry/strace/linux64.go
@@ -335,4 +335,5 @@ var linuxAMD64 = SyscallMap{
315: makeSyscallInfo("sched_getattr", Hex, Hex, Hex),
316: makeSyscallInfo("renameat2", FD, Path, Hex, Path, Hex),
317: makeSyscallInfo("seccomp", Hex, Hex, Hex),
+ 332: makeSyscallInfo("statx", FD, Path, Hex, Hex, Hex),
}
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index 33a40b9c6..e76ee27d2 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -74,6 +74,7 @@ go_library(
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/kernel/shm",
+ "//pkg/sentry/kernel/signalfd",
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/memmap",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index ed996ba51..18d24ab61 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -320,21 +320,21 @@ var AMD64 = &kernel.SyscallTable{
272: syscalls.PartiallySupported("unshare", Unshare, "Mount, cgroup namespaces not supported. Network namespaces supported but must be empty.", nil),
273: syscalls.Error("set_robust_list", syserror.ENOSYS, "Obsolete.", nil),
274: syscalls.Error("get_robust_list", syserror.ENOSYS, "Obsolete.", nil),
- 275: syscalls.PartiallySupported("splice", Splice, "Stub implementation.", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
- 276: syscalls.ErrorWithEvent("tee", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
+ 275: syscalls.Supported("splice", Splice),
+ 276: syscalls.Supported("tee", Tee),
277: syscalls.PartiallySupported("sync_file_range", SyncFileRange, "Full data flush is not guaranteed at this time.", nil),
278: syscalls.ErrorWithEvent("vmsplice", syserror.ENOSYS, "", []string{"gvisor.dev/issue/138"}), // TODO(b/29354098)
279: syscalls.CapError("move_pages", linux.CAP_SYS_NICE, "", nil), // requires cap_sys_nice (mostly)
280: syscalls.Supported("utimensat", Utimensat),
281: syscalls.Supported("epoll_pwait", EpollPwait),
- 282: syscalls.ErrorWithEvent("signalfd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
+ 282: syscalls.PartiallySupported("signalfd", Signalfd, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
283: syscalls.Supported("timerfd_create", TimerfdCreate),
284: syscalls.Supported("eventfd", Eventfd),
285: syscalls.PartiallySupported("fallocate", Fallocate, "Not all options are supported.", nil),
286: syscalls.Supported("timerfd_settime", TimerfdSettime),
287: syscalls.Supported("timerfd_gettime", TimerfdGettime),
288: syscalls.Supported("accept4", Accept4),
- 289: syscalls.ErrorWithEvent("signalfd4", syserror.ENOSYS, "", []string{"gvisor.dev/issue/139"}), // TODO(b/19846426)
+ 289: syscalls.PartiallySupported("signalfd4", Signalfd4, "Semantics are slightly different.", []string{"gvisor.dev/issue/139"}),
290: syscalls.Supported("eventfd2", Eventfd2),
291: syscalls.Supported("epoll_create1", EpollCreate1),
292: syscalls.Supported("dup3", Dup3),
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 2e00a91ce..b9a8e3e21 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -1423,9 +1423,6 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
if err != nil {
return err
}
- if dirPath {
- return syserror.ENOENT
- }
return fileOpAt(t, dirFD, path, func(root *fs.Dirent, d *fs.Dirent, name string, _ uint) error {
if !fs.IsDir(d.Inode.StableAttr) {
@@ -1436,7 +1433,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return err
}
- return d.Remove(t, root, name)
+ return d.Remove(t, root, name, dirPath)
})
}
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index 0104a94c0..fb6efd5d8 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -20,7 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/signalfd"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -506,3 +509,77 @@ func RestartSyscall(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne
t.Debugf("Restart block missing in restart_syscall(2). Did ptrace inject a return value of ERESTART_RESTARTBLOCK?")
return 0, nil, syserror.EINTR
}
+
+// sharedSignalfd is shared between the two calls.
+func sharedSignalfd(t *kernel.Task, fd int32, sigset usermem.Addr, sigsetsize uint, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ // Copy in the signal mask.
+ mask, err := copyInSigSet(t, sigset, sigsetsize)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Always check for valid flags, even if not creating.
+ if flags&^(linux.SFD_NONBLOCK|linux.SFD_CLOEXEC) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Is this a change to an existing signalfd?
+ //
+ // The spec indicates that this should adjust the mask.
+ if fd != -1 {
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Is this a signalfd?
+ if s, ok := file.FileOperations.(*signalfd.SignalOperations); ok {
+ s.SetMask(mask)
+ return 0, nil, nil
+ }
+
+ // Not a signalfd.
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Create a new file.
+ file, err := signalfd.New(t, mask)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ // Set appropriate flags.
+ file.SetFlags(fs.SettableFileFlags{
+ NonBlocking: flags&linux.SFD_NONBLOCK != 0,
+ })
+
+ // Create a new descriptor.
+ fd, err = t.NewFDFrom(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.SFD_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Done.
+ return uintptr(fd), nil, nil
+}
+
+// Signalfd implements the linux syscall signalfd(2).
+func Signalfd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, 0)
+}
+
+// Signalfd4 implements the linux syscall signalfd4(2).
+func Signalfd4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ sigset := args[1].Pointer()
+ sigsetsize := args[2].SizeT()
+ flags := args[3].Int()
+ return sharedSignalfd(t, fd, sigset, sigsetsize, flags)
+}
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index 8a98fedcb..f0a292f2f 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -29,9 +29,8 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
total int64
n int64
err error
- ch chan struct{}
- inW bool
- outW bool
+ inCh chan struct{}
+ outCh chan struct{}
)
for opts.Length > 0 {
n, err = fs.Splice(t, outFile, inFile, opts)
@@ -43,35 +42,33 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB
break
}
- // Are we a registered waiter?
- if ch == nil {
- ch = make(chan struct{}, 1)
- }
- if !inW && !inFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- inFile.EventRegister(&w, EventMaskRead)
- defer inFile.EventUnregister(&w)
- inW = true // Registered.
- } else if !outW && !outFile.Flags().NonBlocking {
- w, _ := waiter.NewChannelEntry(ch)
- outFile.EventRegister(&w, EventMaskWrite)
- defer outFile.EventUnregister(&w)
- outW = true // Registered.
- }
-
- // Was anything registered? If no, everything is non-blocking.
- if !inW && !outW {
- break
- }
-
- if (!inW || inFile.Readiness(EventMaskRead) != 0) && (!outW || outFile.Readiness(EventMaskWrite) != 0) {
- // Something became ready, try again without blocking.
- continue
+ // Note that the blocking behavior here is a bit different than the
+ // normal pattern. Because we need to have both data to read and data
+ // to write simultaneously, we actually explicitly block on both of
+ // these cases in turn before returning to the splice operation.
+ if inFile.Readiness(EventMaskRead) == 0 {
+ if inCh == nil {
+ inCh = make(chan struct{}, 1)
+ inW, _ := waiter.NewChannelEntry(inCh)
+ inFile.EventRegister(&inW, EventMaskRead)
+ defer inFile.EventUnregister(&inW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(inCh); err != nil {
+ break
+ }
}
-
- // Block until there's data.
- if err = t.Block(ch); err != nil {
- break
+ if outFile.Readiness(EventMaskWrite) == 0 {
+ if outCh == nil {
+ outCh = make(chan struct{}, 1)
+ outW, _ := waiter.NewChannelEntry(outCh)
+ outFile.EventRegister(&outW, EventMaskWrite)
+ defer outFile.EventUnregister(&outW)
+ continue // Need to refresh readiness.
+ }
+ if err = t.Block(outCh); err != nil {
+ break
+ }
}
}
@@ -149,7 +146,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
Length: count,
SrcOffset: true,
SrcStart: offset,
- }, false)
+ }, outFile.Flags().NonBlocking)
// Copy out the new offset.
if _, err := t.CopyOut(offsetAddr, n+offset); err != nil {
@@ -159,7 +156,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
// Send data using splice.
n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
- }, false)
+ }, outFile.Flags().NonBlocking)
}
// We can only pass a single file to handleIOError, so pick inFile
@@ -181,12 +178,6 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful. Note that unlike in Linux, this
- // flag is applied consistently. We will have either fully blocking or
- // non-blocking behavior below, regardless of the underlying files
- // being spliced to. It's unclear if this is a bug or not yet.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -200,6 +191,13 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
defer inFile.DecRef()
+ // The operation is non-blocking if anything is non-blocking.
+ //
+ // N.B. This is a rather simplistic heuristic that avoids some
+ // poor edge case behavior since the exact semantics here are
+ // underspecified and vary between versions of Linux itself.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Construct our options.
//
// Note that exactly one of the underlying buffers must be a pipe. We
@@ -257,7 +255,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
}
// Splice data.
- n, err := doSplice(t, outFile, inFile, opts, nonBlocking)
+ n, err := doSplice(t, outFile, inFile, opts, nonBlock)
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile)
@@ -275,9 +273,6 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
- // Only non-blocking is meaningful.
- nonBlocking := (flags & linux.SPLICE_F_NONBLOCK) != 0
-
// Get files.
outFile := t.GetFile(outFD)
if outFile == nil {
@@ -301,11 +296,14 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return 0, nil, syserror.EINVAL
}
+ // The operation is non-blocking if anything is non-blocking.
+ nonBlock := inFile.Flags().NonBlocking || outFile.Flags().NonBlocking || (flags&linux.SPLICE_F_NONBLOCK != 0)
+
// Splice data.
n, err := doSplice(t, outFile, inFile, fs.SpliceOpts{
Length: count,
Dup: true,
- }, nonBlocking)
+ }, nonBlock)
// See above; inFile is chosen arbitrarily here.
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "tee", inFile)
diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD
index 8aa6a3017..beb43ba13 100644
--- a/pkg/sentry/time/BUILD
+++ b/pkg/sentry/time/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
index b69603da3..fc7614fff 100644
--- a/pkg/sentry/unimpl/BUILD
+++ b/pkg/sentry/unimpl/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -10,6 +11,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "unimplemented_syscall_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":unimplemented_syscall_proto"],
+)
+
go_proto_library(
name = "unimplemented_syscall_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto",
diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD
index a5b4206bb..cc5d25762 100644
--- a/pkg/sentry/usermem/BUILD
+++ b/pkg/sentry/usermem/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "addr_range",
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 0f247bf77..eff4b44f6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 86bde7fb3..7eb2b2821 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -199,8 +199,11 @@ type Dirent struct {
// Ino is the inode number.
Ino uint64
- // Off is this Dirent's offset.
- Off int64
+ // NextOff is the offset of the *next* Dirent in the directory; that is,
+ // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will
+ // cause the next call to FileDescription.IterDirents() to yield the next
+ // Dirent. (The offset of the first Dirent in a directory is always 0.)
+ NextOff int64
}
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
diff --git a/pkg/sleep/BUILD b/pkg/sleep/BUILD
index 00665c939..bdca80d37 100644
--- a/pkg/sleep/BUILD
+++ b/pkg/sleep/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/state/BUILD b/pkg/state/BUILD
index c0f3c658d..329904457 100644
--- a/pkg/state/BUILD
+++ b/pkg/state/BUILD
@@ -1,5 +1,6 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/state/statefile/BUILD b/pkg/state/statefile/BUILD
index e70f4a79f..8a865d229 100644
--- a/pkg/state/statefile/BUILD
+++ b/pkg/state/statefile/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/syserror/BUILD b/pkg/syserror/BUILD
index b149f9e02..bd3f9fd28 100644
--- a/pkg/syserror/BUILD
+++ b/pkg/syserror/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index df37c7d5a..3fd9e3134 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "tcpip",
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
index 0d2637ee4..78df5a0b1 100644
--- a/pkg/tcpip/adapters/gonet/BUILD
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index 3301967fb..b4e8d6810 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "buffer",
diff --git a/pkg/tcpip/hash/jenkins/BUILD b/pkg/tcpip/hash/jenkins/BUILD
index 29b30be9c..0c5c20cea 100644
--- a/pkg/tcpip/hash/jenkins/BUILD
+++ b/pkg/tcpip/hash/jenkins/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
index 76ef02f13..b558350c3 100644
--- a/pkg/tcpip/header/BUILD
+++ b/pkg/tcpip/header/BUILD
@@ -1,6 +1,8 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_library(
name = "header",
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 093850e25..9d3abc0e4 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -76,6 +76,13 @@ const (
// IPv6Version is the version of the ipv6 protocol.
IPv6Version = 6
+ // IPv6AllNodesMulticastAddress is a link-local multicast group that
+ // all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all nodes on a link.
+ //
+ // The address is ff02::1.
+ IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
@@ -221,6 +228,24 @@ func IsV6MulticastAddress(addr tcpip.Address) bool {
return addr[0] == 0xff
}
+// IsV6UnicastAddress determines if the provided address is a valid IPv6
+// unicast (and specified) address. That is, IsV6UnicastAddress returns
+// true if addr contains IPv6AddressSize bytes, is not the unspecified
+// address and is not a multicast address.
+func IsV6UnicastAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ // Must not be unspecified
+ if addr == IPv6Any {
+ return false
+ }
+
+ // Return if not a multicast.
+ return addr[0] != 0xff
+}
+
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
index c1f454805..74412c894 100644
--- a/pkg/tcpip/header/udp.go
+++ b/pkg/tcpip/header/udp.go
@@ -27,6 +27,11 @@ const (
udpChecksum = 6
)
+const (
+ // UDPMaximumPacketSize is the largest possible UDP packet.
+ UDPMaximumPacketSize = 0xffff
+)
+
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index c40744b8e..18adb2085 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -44,14 +44,12 @@ type Endpoint struct {
}
// New creates a new channel endpoint.
-func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
+func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
+ return &Endpoint{
C: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
-
- return stack.RegisterLinkEndpoint(e), e
}
// Drain removes all outbound packets from the channel and counts them.
@@ -135,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
index 74fbbb896..8fa9e3984 100644
--- a/pkg/tcpip/link/fdbased/BUILD
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index 77f988b9f..584db710e 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,6 +41,7 @@ package fdbased
import (
"fmt"
+ "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -81,6 +82,7 @@ const (
PacketMMap
)
+// An endpoint implements the link-layer using a message-oriented file descriptor.
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -114,6 +116,9 @@ type endpoint struct {
// gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
// disabled.
gsoMaxSize uint32
+
+ // wg keeps track of running goroutines.
+ wg sync.WaitGroup
}
// Options specify the details about the fd-based endpoint to be created.
@@ -164,8 +169,9 @@ type Options struct {
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
-// open for the lifetime of the returned endpoint.
-func New(opts *Options) (tcpip.LinkEndpointID, error) {
+// open for the lifetime of the returned endpoint (until after the endpoint has
+// stopped being using and Wait returns).
+func New(opts *Options) (stack.LinkEndpoint, error) {
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
caps |= stack.CapabilityRXChecksumOffload
@@ -190,7 +196,7 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
if len(opts.FDs) == 0 {
- return 0, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
+ return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
e := &endpoint{
@@ -207,12 +213,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
for i := 0; i < len(e.fds); i++ {
fd := e.fds[i]
if err := syscall.SetNonblock(fd, true); err != nil {
- return 0, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
+ return nil, fmt.Errorf("syscall.SetNonblock(%v) failed: %v", fd, err)
}
isSocket, err := isSocketFD(fd)
if err != nil {
- return 0, err
+ return nil, err
}
if isSocket {
if opts.GSOMaxSize != 0 {
@@ -222,12 +228,12 @@ func New(opts *Options) (tcpip.LinkEndpointID, error) {
}
inboundDispatcher, err := createInboundDispatcher(e, fd, isSocket)
if err != nil {
- return 0, fmt.Errorf("createInboundDispatcher(...) = %v", err)
+ return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err)
}
e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
}
- return stack.RegisterLinkEndpoint(e), nil
+ return e, nil
}
func createInboundDispatcher(e *endpoint, fd int, isSocket bool) (linkDispatcher, error) {
@@ -290,7 +296,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
for i := range e.inboundDispatchers {
- go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
}
}
@@ -320,6 +330,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
+// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
+// reading from its FD.
+func (e *endpoint) Wait() {
+ e.wg.Wait()
+}
+
// virtioNetHdr is declared in linux/virtio_net.h.
type virtioNetHdr struct {
flags uint8
@@ -435,14 +451,12 @@ func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buf
}
// NewInjectable creates a new fd-based InjectableEndpoint.
-func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) (tcpip.LinkEndpointID, *InjectableEndpoint) {
+func NewInjectable(fd int, mtu uint32, capabilities stack.LinkEndpointCapabilities) *InjectableEndpoint {
syscall.SetNonblock(fd, true)
- e := &InjectableEndpoint{endpoint: endpoint{
+ return &InjectableEndpoint{endpoint: endpoint{
fds: []int{fd},
mtu: mtu,
caps: capabilities,
}}
-
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index e305252d6..04406bc9a 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -68,11 +68,10 @@ func newContext(t *testing.T, opt *Options) *context {
}
opt.FDs = []int{fds[1]}
- epID, err := New(opt)
+ ep, err := New(opt)
if err != nil {
t.Fatalf("Failed to create FD endpoint: %v", err)
}
- ep := stack.FindLinkEndpoint(epID).(*endpoint)
c := &context{
t: t,
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index ab6a53988..b36629d2c 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -32,8 +32,8 @@ type endpoint struct {
// New creates a new loopback endpoint. This link-layer endpoint just turns
// outbound packets into inbound packets.
-func New() tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{})
+func New() stack.LinkEndpoint {
+ return &endpoint{}
}
// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
@@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/muxed/BUILD b/pkg/tcpip/link/muxed/BUILD
index ea12ef1ac..1bab380b0 100644
--- a/pkg/tcpip/link/muxed/BUILD
+++ b/pkg/tcpip/link/muxed/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index a577a3d52..7c946101d 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -104,10 +104,16 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *
return endpoint.WriteRawPacket(dest, packet)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (m *InjectableEndpoint) Wait() {
+ for _, ep := range m.routes {
+ ep.Wait()
+ }
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
-func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) (tcpip.LinkEndpointID, *InjectableEndpoint) {
- e := &InjectableEndpoint{
+func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
+ return &InjectableEndpoint{
routes: routes,
}
- return stack.RegisterLinkEndpoint(e), e
}
diff --git a/pkg/tcpip/link/muxed/injectable_test.go b/pkg/tcpip/link/muxed/injectable_test.go
index 174b9330f..3086fec00 100644
--- a/pkg/tcpip/link/muxed/injectable_test.go
+++ b/pkg/tcpip/link/muxed/injectable_test.go
@@ -87,8 +87,8 @@ func makeTestInjectableEndpoint(t *testing.T) (*InjectableEndpoint, *os.File, tc
if err != nil {
t.Fatal("Failed to create socket pair:", err)
}
- _, underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
+ underlyingEndpoint := fdbased.NewInjectable(pair[1], 6500, stack.CapabilityNone)
routes := map[tcpip.Address]stack.InjectableLinkEndpoint{dstIP: underlyingEndpoint}
- _, endpoint := NewInjectableEndpoint(routes)
+ endpoint := NewInjectableEndpoint(routes)
return endpoint, os.NewFile(uintptr(pair[0]), "test route end"), dstIP
}
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
index f2998aa98..0a5ea3dc4 100644
--- a/pkg/tcpip/link/sharedmem/BUILD
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
index 94725cb11..330ed5e94 100644
--- a/pkg/tcpip/link/sharedmem/pipe/BUILD
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
index 160a8f864..de1ce043d 100644
--- a/pkg/tcpip/link/sharedmem/queue/BUILD
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index 834ea5c40..9e71d4edf 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -94,7 +94,7 @@ type endpoint struct {
// New creates a new shared-memory-based endpoint. Buffers will be broken up
// into buffers of "bufferSize" bytes.
-func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) {
+func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (stack.LinkEndpoint, error) {
e := &endpoint{
mtu: mtu,
bufferSize: bufferSize,
@@ -102,15 +102,15 @@ func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tc
}
if err := e.tx.init(bufferSize, &tx); err != nil {
- return 0, err
+ return nil, err
}
if err := e.rx.init(bufferSize, &rx); err != nil {
e.tx.cleanup()
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(e), nil
+ return e, nil
}
// Close frees all resources associated with the endpoint.
@@ -132,7 +132,8 @@ func (e *endpoint) Close() {
}
}
-// Wait waits until all workers have stopped after a Close() call.
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
func (e *endpoint) Wait() {
e.completed.Wait()
}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 98036f367..0e9ba0846 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -119,12 +119,12 @@ func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress
initQueue(t, &c.txq, &c.txCfg)
initQueue(t, &c.rxq, &c.rxCfg)
- id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+ ep, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
if err != nil {
t.Fatalf("New failed: %v", err)
}
- c.ep = stack.FindLinkEndpoint(id).(*endpoint)
+ c.ep = ep.(*endpoint)
c.ep.Attach(c)
return c
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 36c8c46fc..e401dce44 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -58,10 +58,10 @@ type endpoint struct {
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
-func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
- })
+func New(lower stack.LinkEndpoint) stack.LinkEndpoint {
+ return &endpoint{
+ lower: lower,
+ }
}
func zoneOffset() (int32, error) {
@@ -102,15 +102,15 @@ func writePCAPHeader(w io.Writer, maxLen uint32) error {
// snapLen is the maximum amount of a packet to be saved. Packets with a length
// less than or equal too snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
-func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+func NewWithFile(lower stack.LinkEndpoint, file *os.File, snapLen uint32) (stack.LinkEndpoint, error) {
if err := writePCAPHeader(file, snapLen); err != nil {
- return 0, err
+ return nil, err
}
- return stack.RegisterLinkEndpoint(&endpoint{
- lower: stack.FindLinkEndpoint(lower),
+ return &endpoint{
+ lower: lower,
file: file,
maxPCAPLen: snapLen,
- }), nil
+ }, nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
@@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
index 2597d4b3e..0746dc8ec 100644
--- a/pkg/tcpip/link/waitable/BUILD
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 3b6ac2ff7..5a1791cb5 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -40,11 +40,10 @@ type Endpoint struct {
// New creates a new waitable link-layer endpoint. It wraps around another
// endpoint and allows the caller to block new write/dispatch calls and wait for
// the inflight ones to finish before returning.
-func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) {
- e := &Endpoint{
- lower: stack.FindLinkEndpoint(lower),
+func New(lower stack.LinkEndpoint) *Endpoint {
+ return &Endpoint{
+ lower: lower,
}
- return stack.RegisterLinkEndpoint(e), e
}
// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
@@ -121,3 +120,6 @@ func (e *Endpoint) WaitWrite() {
func (e *Endpoint) WaitDispatch() {
e.dispatchGate.Close()
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 56e18ecb0..ae23c96b7 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -70,9 +70,12 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Write and check that it goes through.
wep.WritePacket(nil, nil /* gso */, buffer.Prependable{}, buffer.VectorisedView{}, 0)
@@ -97,7 +100,7 @@ func TestWaitWrite(t *testing.T) {
func TestWaitDispatch(t *testing.T) {
ep := &countedEndpoint{}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
// Check that attach happens.
wep.Attach(ep)
@@ -139,7 +142,7 @@ func TestOtherMethods(t *testing.T) {
hdrLen: hdrLen,
linkAddr: linkAddr,
}
- _, wep := New(stack.RegisterLinkEndpoint(ep))
+ wep := New(ep)
if v := wep.MTU(); v != mtu {
t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index f36f49453..9d16ff8c9 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -1,4 +1,4 @@
-load("//tools/go_stateify:defs.bzl", "go_test")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
index d95d44f56..df0d3a8c0 100644
--- a/pkg/tcpip/network/arp/BUILD
+++ b/pkg/tcpip/network/arp/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 4c4b54469..387fca96e 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -47,11 +47,13 @@ func newTestContext(t *testing.T) *testContext {
s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})
const defaultMTU = 65536
- id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(256, defaultMTU, stackLinkAddr)
+ wep := stack.LinkEndpoint(ep)
+
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -73,7 +75,7 @@ func newTestContext(t *testing.T) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
index 118bfc763..c5c7aad86 100644
--- a/pkg/tcpip/network/fragmentation/BUILD
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "reassembler_list",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4b3bd74fa..6a40e7ee3 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index be84fa63d..58e537aad 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 1b5a55bea..ae827ca27 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -36,11 +36,11 @@ func TestExcludeBroadcast(t *testing.T) {
s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
const defaultMTU = 65536
- id, _ := channel.New(256, defaultMTU, "")
+ ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
if testing.Verbose() {
- id = sniffer.New(id)
+ ep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -184,15 +184,12 @@ type errorChannel struct {
// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket
// will return successive errors from packetCollectorErrors until the list is
// empty and then return nil each time.
-func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) {
- _, e := channel.New(size, mtu, linkAddr)
- ec := errorChannel{
- Endpoint: e,
+func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel {
+ return &errorChannel{
+ Endpoint: channel.New(size, mtu, linkAddr),
Ch: make(chan packetInfo, size),
packetCollectorErrors: packetCollectorErrors,
}
-
- return stack.RegisterLinkEndpoint(e), &ec
}
// packetInfo holds all the information about an outbound packet.
@@ -242,9 +239,8 @@ type context struct {
func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
// Make the packet and write it.
s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
- _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
- linkEPId := stack.RegisterLinkEndpoint(linkEP)
- s.CreateNIC(1, linkEPId)
+ ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
+ s.CreateNIC(1, ep)
const (
src = "\x10\x00\x00\x01"
dst = "\x10\x00\x00\x02"
@@ -266,7 +262,7 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32
}
return context{
Route: r,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
index c71b69123..f06622a8b 100644
--- a/pkg/tcpip/network/ipv6/BUILD
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
@@ -25,6 +26,7 @@ go_test(
size = "small",
srcs = [
"icmp_test.go",
+ "ipv6_test.go",
"ndp_test.go",
],
embed = [":ipv6"],
@@ -36,6 +38,7 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 227a65cf2..653d984e9 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -83,8 +83,7 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li
func TestICMPCounts(t *testing.T) {
s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
{
- id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
@@ -211,36 +210,27 @@ func newTestContext(t *testing.T) *testContext {
}
const defaultMTU = 65536
- _, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
- c.linkEP0 = linkEP0
- wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
- id0 := stack.RegisterLinkEndpoint(wrappedEP0)
+ c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+
+ wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
- id0 = sniffer.New(id0)
+ wrappedEP0 = sniffer.New(wrappedEP0)
}
- if err := c.s0.CreateNIC(1, id0); err != nil {
+ if err := c.s0.CreateNIC(1, wrappedEP0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil {
- t.Fatalf("AddAddress sn lladdr0: %v", err)
- }
- _, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
- c.linkEP1 = linkEP1
- wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
- id1 := stack.RegisterLinkEndpoint(wrappedEP1)
- if err := c.s1.CreateNIC(1, id1); err != nil {
+ c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
+ if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
- if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil {
- t.Fatalf("AddAddress sn lladdr1: %v", err)
- }
subnet0, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1))))
if err != nil {
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
new file mode 100644
index 000000000..57bcd5455
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -0,0 +1,258 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ipv6
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ // The least significant 3 bytes are the same as addr2 so both addr2 and
+ // addr3 will have the same solicited-node address.
+ addr3 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x02"
+)
+
+// testReceiveICMP tests receiving an ICMP packet from src to dst. want is the
+// expected Neighbor Advertisement received count after receiving the packet.
+func testReceiveICMP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ // Receive ICMP packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
+ pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
+ pkt.SetType(header.ICMPv6NeighborAdvert)
+ pkt.SetChecksum(header.ICMPv6Checksum(pkt, src, dst, buffer.VectorisedView{}))
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stats := s.Stats().ICMP.V6PacketsReceived
+
+ if got := stats.NeighborAdvert.Value(); got != want {
+ t.Fatalf("got NeighborAdvert = %d, want = %d", got, want)
+ }
+}
+
+// testReceiveICMP tests receiving a UDP packet from src to dst. want is the
+// expected UDP received count after receiving the packet.
+func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64) {
+ t.Helper()
+
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+ defer close(ch)
+
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: dst, Port: 80}); err != nil {
+ t.Fatalf("ep.Bind(...) failed: %v", err)
+ }
+
+ // Receive UDP Packet.
+ hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.UDPMinimumSize)
+ u := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ u.Encode(&header.UDPFields{
+ SrcPort: 5555,
+ DstPort: 80,
+ Length: header.UDPMinimumSize,
+ })
+
+ // UDP pseudo-header checksum.
+ sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, src, dst, header.UDPMinimumSize)
+
+ // UDP checksum
+ sum = header.Checksum(header.UDP([]byte{}), sum)
+ u.SetChecksum(^u.CalculateChecksum(sum))
+
+ payloadLength := hdr.UsedLength()
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(payloadLength),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 255,
+ SrcAddr: src,
+ DstAddr: dst,
+ })
+
+ e.Inject(ProtocolNumber, hdr.View().ToVectorisedView())
+
+ stat := s.Stats().UDP.PacketsReceived
+
+ if got := stat.Value(); got != want {
+ t.Fatalf("got UDPPacketsReceived = %d, want = %d", got, want)
+ }
+}
+
+// TestReceiveOnAllNodesMulticastAddr tests that IPv6 endpoints receive ICMP and
+// UDP packets destined to the IPv6 link-local all-nodes multicast address.
+func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolName string
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.ProtocolName6, testReceiveICMP},
+ {"UDP", udp.ProtocolName, testReceiveUDP},
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{})
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should receive a packet destined to the all-nodes
+ // multicast address.
+ test.rxf(t, s, e, addr1, header.IPv6AllNodesMulticastAddress, 1)
+ })
+ }
+}
+
+// TestReceiveOnSolicitedNodeAddr tests that IPv6 endpoints receive ICMP and UDP
+// packets destined to the IPv6 solicited-node address of an assigned IPv6
+// address.
+func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
+ tests := []struct {
+ name string
+ protocolName string
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ }{
+ {"ICMP", icmp.ProtocolName6, testReceiveICMP},
+ {"UDP", udp.ProtocolName, testReceiveUDP},
+ }
+
+ snmc := header.SolicitedNodeAddr(addr2)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{})
+ e := channel.New(10, 1280, linkAddr1)
+ if err := s.CreateNIC(1, e); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as we haven't added
+ // those addresses.
+ test.rxf(t, s, e, addr1, snmc, 0)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr2, err)
+ }
+
+ // Should receive a packet destined to the solicited
+ // node address of addr2/addr3 now that we have added
+ // added addr2.
+ test.rxf(t, s, e, addr1, snmc, 1)
+
+ if err := s.AddAddress(1, ProtocolNumber, addr3); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, addr3, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have added addr3.
+ test.rxf(t, s, e, addr1, snmc, 2)
+
+ if err := s.RemoveAddress(1, addr2); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr2, err)
+ }
+
+ // Should still receive a packet destined to the
+ // solicited node address of addr2/addr3 now that we
+ // have removed addr2.
+ test.rxf(t, s, e, addr1, snmc, 3)
+
+ if err := s.RemoveAddress(1, addr3); err != nil {
+ t.Fatalf("RemoveAddress(_, %s) = %s", addr3, err)
+ }
+
+ // Should not receive a packet destined to the solicited
+ // node address of addr2/addr3 yet as both of them got
+ // removed.
+ test.rxf(t, s, e, addr1, snmc, 3)
+ })
+ }
+}
+
+// TestAddIpv6Address tests adding IPv6 addresses.
+func TestAddIpv6Address(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ }{
+ // This test is in response to b/140943433.
+ {
+ "Nil",
+ tcpip.Address([]byte(nil)),
+ },
+ {
+ "ValidUnicast",
+ addr1,
+ },
+ {
+ "ValidLinkLocalUnicast",
+ lladdr0,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New([]string{ProtocolName}, nil, stack.Options{})
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+
+ if err := s.AddAddress(1, ProtocolNumber, test.addr); err != nil {
+ t.Fatalf("AddAddress(_, %d, nil) = %s", ProtocolNumber, err)
+ }
+
+ addr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ }
+ if addr.Address != test.addr {
+ t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", addr.Address, test.addr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 8e4cf0e74..571915d3f 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -32,15 +32,14 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack
t.Helper()
s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
- {
- id := stack.RegisterLinkEndpoint(&stubLinkEndpoint{})
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatalf("CreateNIC(_) = %s", err)
- }
- if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
- t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
- }
+
+ if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
+ t.Fatalf("CreateNIC(_) = %s", err)
+ }
+ if err := s.AddAddress(1, ProtocolNumber, llladdr); err != nil {
+ t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, llladdr, err)
}
+
{
subnet, err := tcpip.NewSubnet(rlladdr, tcpip.AddressMask(strings.Repeat("\xff", len(rlladdr))))
if err != nil {
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 989058413..11efb4e44 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index e2021cd15..f12189580 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -138,11 +138,11 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
+ linkEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: mtu})
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil {
+ if err := s.CreateNIC(1, sniffer.New(linkEP)); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 1716be285..329941775 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -128,7 +128,7 @@ func main() {
log.Fatal(err)
}
- linkID, err := fdbased.New(&fdbased.Options{
+ linkEP, err := fdbased.New(&fdbased.Options{
FDs: []int{fd},
MTU: mtu,
EthernetHeader: *tap,
@@ -137,7 +137,7 @@ func main() {
if err != nil {
log.Fatal(err)
}
- if err := s.CreateNIC(1, linkID); err != nil {
+ if err := s.CreateNIC(1, linkEP); err != nil {
log.Fatal(err)
}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 788de3dfe..28c49e8ff 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "linkaddrentry_list",
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
index f8156be47..3a20839da 100644
--- a/pkg/tcpip/stack/icmp_rate_limit.go
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -15,8 +15,6 @@
package stack
import (
- "sync"
-
"golang.org/x/time/rate"
)
@@ -33,54 +31,11 @@ const (
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
type ICMPRateLimiter struct {
- mu sync.RWMutex
- l *rate.Limiter
+ *rate.Limiter
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
// at which ICMP messages are generated by the stack.
func NewICMPRateLimiter() *ICMPRateLimiter {
- return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)}
-}
-
-// Allow returns true if we are allowed to send at least 1 message at the
-// moment.
-func (i *ICMPRateLimiter) Allow() bool {
- i.mu.RLock()
- allow := i.l.Allow()
- i.mu.RUnlock()
- return allow
-}
-
-// Limit returns the maximum number of ICMP messages that can be sent in one
-// second.
-func (i *ICMPRateLimiter) Limit() rate.Limit {
- i.mu.RLock()
- defer i.mu.RUnlock()
- return i.l.Limit()
-}
-
-// SetLimit sets the maximum number of ICMP messages that can be sent in one
-// second.
-func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) {
- i.mu.RLock()
- defer i.mu.RUnlock()
- i.l.SetLimit(newLimit)
-}
-
-// Burst returns how many ICMP messages can be sent at any single instant.
-func (i *ICMPRateLimiter) Burst() int {
- i.mu.RLock()
- defer i.mu.RUnlock()
- return i.l.Burst()
-}
-
-// SetBurst sets the maximum number of ICMP messages allowed at any single
-// instant.
-//
-// NOTE: Changing Burst causes the underlying rate limiter to be recreated.
-func (i *ICMPRateLimiter) SetBurst(burst int) {
- i.mu.Lock()
- i.l = rate.NewLimiter(i.l.Limit(), burst)
- i.mu.Unlock()
+ return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 43719085e..a719058b4 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -102,6 +102,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
}
}
+// enable enables the NIC. enable will attach the link to its LinkEndpoint and
+// join the IPv6 All-Nodes Multicast address (ff02::1).
+func (n *NIC) enable() *tcpip.Error {
+ n.attachLinkEndpoint()
+
+ // Join the IPv6 All-Nodes Multicast group if the stack is configured to
+ // use IPv6. This is required to ensure that this node properly receives
+ // and responds to the various NDP messages that are destined to the
+ // all-nodes multicast address. An example is the Neighbor Advertisement
+ // when we perform Duplicate Address Detection, or Router Advertisement
+ // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861
+ // section 4.2 for more information.
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress)
+ }
+
+ return nil
+}
+
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
// to start delivering packets.
func (n *NIC) attachLinkEndpoint() {
@@ -307,6 +326,8 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p
}
func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) {
+ // TODO(b/141022673): Validate IP address before adding them.
+
// Sanity check.
id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address}
if _, ok := n.endpoints[id]; ok {
@@ -339,6 +360,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar
}
}
+ // If we are adding an IPv6 unicast address, join the solicited-node
+ // multicast address.
+ if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) {
+ snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address)
+ if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil {
+ return nil, err
+ }
+ }
+
n.endpoints[id] = ref
l, ok := n.primary[protocolAddress.Protocol]
@@ -467,13 +497,27 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
}
func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error {
- r := n.endpoints[NetworkEndpointID{addr}]
- if r == nil || r.getKind() != permanent {
+ r, ok := n.endpoints[NetworkEndpointID{addr}]
+ if !ok || r.getKind() != permanent {
return tcpip.ErrBadLocalAddress
}
r.setKind(permanentExpired)
- r.decRefLocked()
+ if !r.decRefLocked() {
+ // The endpoint still has references to it.
+ return nil
+ }
+
+ // At this point the endpoint is deleted.
+
+ // If we are removing an IPv6 unicast address, leave the solicited-node
+ // multicast address.
+ if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) {
+ snmc := header.SolicitedNodeAddr(addr)
+ if err := n.leaveGroupLocked(snmc); err != nil {
+ return err
+ }
+ }
return nil
}
@@ -491,6 +535,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address
n.mu.Lock()
defer n.mu.Unlock()
+ return n.joinGroupLocked(protocol, addr)
+}
+
+// joinGroupLocked adds a new endpoint for the given multicast address, if none
+// exists yet. Otherwise it just increments its count. n MUST be locked before
+// joinGroupLocked is called.
+func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
if joins == 0 {
@@ -518,6 +569,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
defer n.mu.Unlock()
+ return n.leaveGroupLocked(addr)
+}
+
+// leaveGroupLocked decrements the count for the given multicast address, and
+// when it reaches zero removes the endpoint for this address. n MUST be locked
+// before leaveGroupLocked is called.
+func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
id := NetworkEndpointID{addr}
joins := n.mcastJoins[id]
switch joins {
@@ -802,11 +860,14 @@ func (r *referencedNetworkEndpoint) decRef() {
}
// decRefLocked is the same as decRef but assumes that the NIC.mu mutex is
-// locked.
-func (r *referencedNetworkEndpoint) decRefLocked() {
+// locked. Returns true if the endpoint was removed.
+func (r *referencedNetworkEndpoint) decRefLocked() bool {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpointLocked(r)
+ return true
}
+
+ return false
}
// incRef increments the ref count. It must only be called when the caller is
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 67b70b2ee..07e4c770d 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -15,8 +15,6 @@
package stack
import (
- "sync"
-
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -297,6 +295,15 @@ type LinkEndpoint interface {
// IsAttached returns whether a NetworkDispatcher is attached to the
// endpoint.
IsAttached() bool
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // For now, requesting that an endpoint's worker goroutine(s) stop is
+ // implementation specific.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -379,10 +386,6 @@ var (
networkProtocols = make(map[string]NetworkProtocolFactory)
unassociatedFactory UnassociatedEndpointFactory
-
- linkEPMu sync.RWMutex
- nextLinkEndpointID tcpip.LinkEndpointID = 1
- linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
)
// RegisterTransportProtocolFactory registers a new transport protocol factory
@@ -406,28 +409,6 @@ func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
unassociatedFactory = f
}
-// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
-// ID that can be used to refer to it.
-func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
- linkEPMu.Lock()
- defer linkEPMu.Unlock()
-
- v := nextLinkEndpointID
- nextLinkEndpointID++
-
- linkEndpoints[v] = linkEP
-
- return v
-}
-
-// FindLinkEndpoint finds the link endpoint associated with the given ID.
-func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
- linkEPMu.RLock()
- defer linkEPMu.RUnlock()
-
- return linkEndpoints[id]
-}
-
// GSOType is the type of GSO segments.
//
// +stateify savable
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6beca6ae8..1fe21b68e 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -620,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
-func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error {
- ep := FindLinkEndpoint(linkEP)
- if ep == nil {
- return tcpip.ErrBadLinkEndpoint
- }
-
+func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -638,40 +633,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint
s.nics[id] = n
if enabled {
- n.attachLinkEndpoint()
+ return n.enable()
}
return nil
}
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
-func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, true, false)
+func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, true, false)
}
// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
// and a human-readable name.
-func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, false)
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, false)
}
// CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer
// endpoint, and a human-readable name.
-func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, true, true)
+func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, true, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leave it disable. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
-func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, "", linkEP, false, false)
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, "", ep, false, false)
}
// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
// CreateDisabledNIC.
-func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
- return s.createNIC(id, name, linkEP, false, false)
+func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error {
+ return s.createNIC(id, name, ep, false, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
@@ -685,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
- nic.attachLinkEndpoint()
-
- return nil
+ return nic.enable()
}
// CheckNIC checks if a NIC is usable.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index c6a8160af..0c26c9911 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct {
prefixLen int
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
- linkEP stack.LinkEndpoint
+ ep stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
- return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
+ return f.ep.MTU() - uint32(f.MaxHeaderLength())
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
@@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
- return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
+ return f.ep.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
@@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto
}
func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
- return f.linkEP.Capabilities()
+ return f.ep.Capabilities()
}
func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error {
@@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return nil
}
- return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
+ return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
@@ -189,14 +189,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
-func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address},
prefixLen: addrWithPrefix.PrefixLen,
proto: f,
dispatcher: dispatcher,
- linkEP: linkEP,
+ ep: ep,
}, nil
}
@@ -225,9 +225,9 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
- id, linkEP := channel.New(10, defaultMTU, "")
+ ep := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -245,7 +245,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet with wrong address is not delivered.
buf[0] = 3
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
@@ -255,7 +255,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to first endpoint.
buf[0] = 1
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -265,7 +265,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet is delivered to second endpoint.
buf[0] = 2
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -274,7 +274,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is not delivered if protocol number is wrong.
- linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber-1, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -284,7 +284,7 @@ func TestNetworkReceive(t *testing.T) {
// Make sure packet that is too small is dropped.
buf.CapLength(2)
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
@@ -307,59 +307,59 @@ func send(r stack.Route, payload buffer.View) *tcpip.Error {
return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123)
}
-func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) {
+func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
- linkEP.Drain()
+ ep.Drain()
if err := sendTo(s, addr, payload); err != nil {
t.Error("sendTo failed:", err)
}
- if got, want := linkEP.Drain(), 1; got != want {
+ if got, want := ep.Drain(), 1; got != want {
t.Errorf("sendTo packet count: got = %d, want %d", got, want)
}
}
-func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) {
+func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) {
t.Helper()
- linkEP.Drain()
+ ep.Drain()
if err := send(r, payload); err != nil {
t.Error("send failed:", err)
}
- if got, want := linkEP.Drain(), 1; got != want {
+ if got, want := ep.Drain(), 1; got != want {
t.Errorf("send packet count: got = %d, want %d", got, want)
}
}
-func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := send(r, payload); gotErr != wantErr {
t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr)
}
}
-func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
+func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) {
t.Helper()
if gotErr := sendTo(s, addr, payload); gotErr != wantErr {
t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr)
}
}
-func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
t.Helper()
// testRecvInternal injects one packet, and we expect to receive it.
want := fakeNet.PacketCount(localAddrByte) + 1
- testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
}
-func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) {
+func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) {
t.Helper()
// testRecvInternal injects one packet, and we do NOT expect to receive it.
want := fakeNet.PacketCount(localAddrByte)
- testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want)
+ testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want)
}
-func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) {
+func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) {
t.Helper()
- linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep.Inject(fakeNetNumber, buf.ToVectorisedView())
if got := fakeNet.PacketCount(localAddrByte); got != want {
t.Errorf("receive packet count: got = %d, want %d", got, want)
}
@@ -369,9 +369,9 @@ func TestNetworkSend(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and one
// address: 1. The route table sends all packets through the only
// existing nic.
- id, linkEP := channel.New(10, defaultMTU, "")
+ ep := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
}
@@ -388,7 +388,7 @@ func TestNetworkSend(t *testing.T) {
}
// Make sure that the link-layer endpoint received the outbound packet.
- testSendTo(t, s, "\x03", linkEP, nil)
+ testSendTo(t, s, "\x03", ep, nil)
}
func TestNetworkSendMultiRoute(t *testing.T) {
@@ -397,8 +397,8 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// even addresses.
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -410,8 +410,8 @@ func TestNetworkSendMultiRoute(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -442,10 +442,10 @@ func TestNetworkSendMultiRoute(t *testing.T) {
}
// Send a packet to an odd destination.
- testSendTo(t, s, "\x05", linkEP1, nil)
+ testSendTo(t, s, "\x05", ep1, nil)
// Send a packet to an even destination.
- testSendTo(t, s, "\x06", linkEP2, nil)
+ testSendTo(t, s, "\x06", ep2, nil)
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
@@ -478,8 +478,8 @@ func TestRoutes(t *testing.T) {
// even addresses.
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id1, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -491,8 +491,8 @@ func TestRoutes(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
- id2, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -556,8 +556,8 @@ func TestAddressRemoval(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -578,15 +578,15 @@ func TestAddressRemoval(t *testing.T) {
// Send and receive packets, and verify they are received.
buf[0] = localAddrByte
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testSendTo(t, s, remoteAddr, linkEP, nil)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
// Remove the address, then check that send/receive doesn't work anymore.
if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
@@ -601,9 +601,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
- t.Fatal("CreateNIC failed:", err)
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
@@ -626,17 +626,17 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
// Send and receive packets, and verify they are received.
buf[0] = localAddrByte
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testSend(t, r, linkEP, nil)
- testSendTo(t, s, remoteAddr, linkEP, nil)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
// Remove the address, then check that send/receive doesn't work anymore.
if err := s.RemoveAddress(1, localAddr); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
// Check that removing the same address fails.
if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress {
@@ -690,8 +690,8 @@ func TestEndpointExpiration(t *testing.T) {
t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -724,15 +724,15 @@ func TestEndpointExpiration(t *testing.T) {
//-----------------------
verifyAddress(t, s, nicid, noAddr)
if promiscuous {
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
if spoofing {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, linkEP, nil)
+ // testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
}
// 2. Add Address, everything should work.
@@ -741,8 +741,8 @@ func TestEndpointExpiration(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
verifyAddress(t, s, nicid, localAddr)
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testSendTo(t, s, remoteAddr, linkEP, nil)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
// 3. Remove the address, send should only work for spoofing, receive
// for promiscuous mode.
@@ -752,15 +752,15 @@ func TestEndpointExpiration(t *testing.T) {
}
verifyAddress(t, s, nicid, noAddr)
if promiscuous {
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
if spoofing {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, linkEP, nil)
+ // testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
}
// 4. Add Address back, everything should work again.
@@ -769,8 +769,8 @@ func TestEndpointExpiration(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
verifyAddress(t, s, nicid, localAddr)
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testSendTo(t, s, remoteAddr, linkEP, nil)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
// 5. Take a reference to the endpoint by getting a route. Verify that
// we can still send/receive, including sending using the route.
@@ -779,9 +779,9 @@ func TestEndpointExpiration(t *testing.T) {
if err != nil {
t.Fatal("FindRoute failed:", err)
}
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testSendTo(t, s, remoteAddr, linkEP, nil)
- testSend(t, r, linkEP, nil)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
// 6. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
@@ -791,16 +791,16 @@ func TestEndpointExpiration(t *testing.T) {
}
verifyAddress(t, s, nicid, noAddr)
if promiscuous {
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
if spoofing {
- testSend(t, r, linkEP, nil)
- testSendTo(t, s, remoteAddr, linkEP, nil)
+ testSend(t, r, ep, nil)
+ testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState)
- testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
}
// 7. Add Address back, everything should work again.
@@ -809,16 +809,16 @@ func TestEndpointExpiration(t *testing.T) {
t.Fatal("AddAddress failed:", err)
}
verifyAddress(t, s, nicid, localAddr)
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testSendTo(t, s, remoteAddr, linkEP, nil)
- testSend(t, r, linkEP, nil)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
+ testSend(t, r, ep, nil)
// 8. Remove the route, sendTo/recv should still work.
//-----------------------
r.Release()
verifyAddress(t, s, nicid, localAddr)
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
- testSendTo(t, s, remoteAddr, linkEP, nil)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
+ testSendTo(t, s, remoteAddr, ep, nil)
// 9. Remove the address. Send should only work for spoofing, receive
// for promiscuous mode.
@@ -828,15 +828,15 @@ func TestEndpointExpiration(t *testing.T) {
}
verifyAddress(t, s, nicid, noAddr)
if promiscuous {
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
} else {
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
if spoofing {
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, linkEP, nil)
+ // testSendTo(t, s, remoteAddr, ep, nil)
} else {
- testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute)
}
})
}
@@ -846,8 +846,8 @@ func TestEndpointExpiration(t *testing.T) {
func TestPromiscuousMode(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -867,13 +867,13 @@ func TestPromiscuousMode(t *testing.T) {
// have a matching endpoint.
const localAddrByte byte = 0x01
buf[0] = localAddrByte
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */)
@@ -886,7 +886,7 @@ func TestPromiscuousMode(t *testing.T) {
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatal("SetPromiscuousMode failed:", err)
}
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
func TestSpoofingWithAddress(t *testing.T) {
@@ -896,8 +896,8 @@ func TestSpoofingWithAddress(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -936,8 +936,8 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
}
// Sending a packet works.
- testSendTo(t, s, dstAddr, linkEP, nil)
- testSend(t, r, linkEP, nil)
+ testSendTo(t, s, dstAddr, ep, nil)
+ testSend(t, r, ep, nil)
// FindRoute should also work with a local address that exists on the NIC.
r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */)
@@ -951,7 +951,7 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
}
// Sending a packet using the route works.
- testSend(t, r, linkEP, nil)
+ testSend(t, r, ep, nil)
}
func TestSpoofingNoAddress(t *testing.T) {
@@ -960,8 +960,8 @@ func TestSpoofingNoAddress(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -980,7 +980,7 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
// Sending a packet fails.
- testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute)
+ testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute)
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
@@ -999,14 +999,14 @@ func TestSpoofingNoAddress(t *testing.T) {
}
// Sending a packet works.
// FIXME(b/139841518):Spoofing doesn't work if there is no primary address.
- // testSendTo(t, s, remoteAddr, linkEP, nil)
+ // testSendTo(t, s, remoteAddr, ep, nil)
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
s.SetRouteTable([]tcpip.Route{})
@@ -1076,8 +1076,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1132,8 +1132,8 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1159,7 +1159,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
t.Fatal("AddAddressRange failed:", err)
}
- testRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testRecv(t, fakeNet, localAddrByte, ep, buf)
}
func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) {
@@ -1198,8 +1198,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) {
const nicID tcpip.NICID = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicID, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicID, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1236,8 +1236,8 @@ func TestCheckLocalAddressForSubnet(t *testing.T) {
func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, linkEP := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1262,7 +1262,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil {
t.Fatal("AddAddressRange failed:", err)
}
- testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf)
+ testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
}
func TestNetworkOptions(t *testing.T) {
@@ -1320,8 +1320,8 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S
func TestAddresRangeAddRemove(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1361,8 +1361,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
// Insert <canBe> primary and <never> never-primary addresses.
@@ -1426,8 +1426,8 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
func TestGetMainNICAddressAddRemove(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1501,8 +1501,8 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1526,8 +1526,8 @@ func TestAddAddress(t *testing.T) {
func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1558,8 +1558,8 @@ func TestAddProtocolAddress(t *testing.T) {
func TestAddAddressWithOptions(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1587,8 +1587,8 @@ func TestAddAddressWithOptions(t *testing.T) {
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id, _ := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(nicid, id); err != nil {
+ ep := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -1621,8 +1621,8 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
func TestNICStats(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed: ", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
@@ -1639,7 +1639,7 @@ func TestNICStats(t *testing.T) {
// Send a packet to address 1.
buf := buffer.NewView(30)
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want {
t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want)
}
@@ -1653,9 +1653,9 @@ func TestNICStats(t *testing.T) {
if err := sendTo(s, "\x01", payload); err != nil {
t.Fatal("sendTo failed: ", err)
}
- want := uint64(linkEP1.Drain())
+ want := uint64(ep1.Drain())
if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want {
- t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want)
+ t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want)
}
if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want {
@@ -1669,16 +1669,16 @@ func TestNICForwarding(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
s.SetForwarding(true)
- id1, linkEP1 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC #1 failed:", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatal("AddAddress #1 failed:", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatal("CreateNIC #2 failed:", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -1697,10 +1697,10 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to address 3.
buf := buffer.NewView(30)
buf[0] = 3
- linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+ ep1.Inject(fakeNetNumber, buf.ToVectorisedView())
select {
- case <-linkEP2.C:
+ case <-ep2.C:
default:
t.Fatal("Packet not forwarded")
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index ca185279e..0e69ac7c8 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr
return buffer.View{}, tcpip.ControlMessages{}, nil
}
-func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
@@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -278,9 +283,9 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
func TestTransportReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
+ linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -340,9 +345,9 @@ func TestTransportReceive(t *testing.T) {
}
func TestTransportControlReceive(t *testing.T) {
- id, linkEP := channel.New(10, defaultMTU, "")
+ linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -408,9 +413,9 @@ func TestTransportControlReceive(t *testing.T) {
}
func TestTransportSend(t *testing.T) {
- id, _ := channel.New(10, defaultMTU, "")
+ linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -497,16 +502,16 @@ func TestTransportForwarding(t *testing.T) {
s.SetForwarding(true)
// TODO(b/123449044): Change this to a channel NIC.
- id1 := loopback.New()
- if err := s.CreateNIC(1, id1); err != nil {
+ ep1 := loopback.New()
+ if err := s.CreateNIC(1, ep1); err != nil {
t.Fatalf("CreateNIC #1 failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress #1 failed: %v", err)
}
- id2, linkEP2 := channel.New(10, defaultMTU, "")
- if err := s.CreateNIC(2, id2); err != nil {
+ ep2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, ep2); err != nil {
t.Fatalf("CreateNIC #2 failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
@@ -545,7 +550,7 @@ func TestTransportForwarding(t *testing.T) {
req[0] = 1
req[1] = 3
req[2] = byte(fakeTransNumber)
- linkEP2.Inject(fakeNetNumber, req.ToVectorisedView())
+ ep2.Inject(fakeNetNumber, req.ToVectorisedView())
aep, _, err := ep.Accept()
if err != nil || aep == nil {
@@ -559,7 +564,7 @@ func TestTransportForwarding(t *testing.T) {
var p channel.PacketInfo
select {
- case p = <-linkEP2.C:
+ case p = <-ep2.C:
default:
t.Fatal("Response packet not forwarded")
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 418e771d2..c021c67ac 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -261,31 +261,34 @@ type FullAddress struct {
Port uint16
}
-// Payload provides an interface around data that is being sent to an endpoint.
-// This allows the endpoint to request the amount of data it needs based on
-// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
-type Payload interface {
- // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
- Get(size int) ([]byte, *Error)
-
- // Size returns the payload size.
- Size() int
+// Payloader is an interface that provides data.
+//
+// This interface allows the endpoint to request the amount of data it needs
+// based on internal buffers without exposing them.
+type Payloader interface {
+ // FullPayload returns all available bytes.
+ FullPayload() ([]byte, *Error)
+
+ // Payload returns a slice containing at most size bytes.
+ Payload(size int) ([]byte, *Error)
}
-// SlicePayload implements Payload on top of slices for convenience.
+// SlicePayload implements Payloader for slices.
+//
+// This is typically used for tests.
type SlicePayload []byte
-// Get implements Payload.
-func (s SlicePayload) Get(size int) ([]byte, *Error) {
- if size > s.Size() {
- size = s.Size()
- }
- return s[:size], nil
+// FullPayload implements Payloader.FullPayload.
+func (s SlicePayload) FullPayload() ([]byte, *Error) {
+ return s, nil
}
-// Size implements Payload.
-func (s SlicePayload) Size() int {
- return len(s)
+// Payload implements Payloader.Payload.
+func (s SlicePayload) Payload(size int) ([]byte, *Error) {
+ if size > len(s) {
+ size = len(s)
+ }
+ return s[:size], nil
}
// A ControlMessages contains socket control messages for IP sockets.
@@ -338,7 +341,7 @@ type Endpoint interface {
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
- Write(Payload, WriteOptions) (int64, <-chan struct{}, *Error)
+ Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
@@ -398,6 +401,10 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptInt sets a socket option, for simple cases where a value
+ // has the int type.
+ SetSockOptInt(opt SockOpt, v int) *Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
@@ -432,16 +439,33 @@ type WriteOptions struct {
// EndOfRecord has the same semantics as Linux's MSG_EOR.
EndOfRecord bool
+
+ // Atomic means that all data fetched from Payloader must be written to the
+ // endpoint. If Atomic is false, then data fetched from the Payloader may be
+ // discarded if available endpoint buffer space is unsufficient.
+ Atomic bool
}
// SockOpt represents socket options which values have the int type.
type SockOpt int
const (
- // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
- // unread bytes in the input buffer should be returned.
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption SockOpt = iota
+ // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the send buffer size option.
+ SendBufferSizeOption
+
+ // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the receive buffer size option.
+ ReceiveBufferSizeOption
+
+ // SendQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the output buffer should be returned.
+ SendQueueSizeOption
+
// TODO(b/137664753): convert all int socket options to be handled via
// GetSockOptInt.
)
@@ -450,18 +474,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
-// buffer size option.
-type SendBufferSizeOption int
-
-// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
-// receive buffer size option.
-type ReceiveBufferSizeOption int
-
-// SendQueueSizeOption is used in GetSockOpt to specify that the number of
-// unread bytes in the output buffer should be returned.
-type SendQueueSizeOption int
-
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
@@ -600,9 +612,6 @@ func (r Route) String() string {
return out.String()
}
-// LinkEndpointID represents a data link layer endpoint.
-type LinkEndpointID uint64
-
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index e1f622af6..a111fdb2a 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -204,7 +204,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -289,7 +289,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
@@ -319,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -331,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -341,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 13e17e2a6..a02731a5d 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -207,7 +207,7 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes
}
// Write implements tcpip.Endpoint.Write.
-func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
@@ -220,9 +220,8 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
return 0, nil, tcpip.ErrInvalidEndpointState
}
- payloadBytes, err := payload.Get(payload.Size())
+ payloadBytes, err := p.FullPayload()
if err != nil {
- ep.mu.RUnlock()
return 0, nil, err
}
@@ -230,7 +229,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (int64
// destination address, route using that address.
if !ep.associated {
ip := header.IPv4(payloadBytes)
- if !ip.IsValid(payload.Size()) {
+ if !ip.IsValid(len(payloadBytes)) {
ep.mu.RUnlock()
return 0, nil, tcpip.ErrInvalidOptionValue
}
@@ -493,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -505,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
ep.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSize
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
@@ -516,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
- ep.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 1ee1a53f8..39a839ab7 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "tcp_segment_list",
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ac927569a..35b489c68 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -806,7 +806,7 @@ func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
}
// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// Linux completely ignores any address passed to sendto(2) for TCP sockets
// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
// and opts.EndOfRecord are also ignored.
@@ -821,47 +821,52 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, err
}
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
-
- // Nothing to do if the buffer is empty.
- if p.Size() == 0 {
- return 0, nil, nil
+ // We can release locks while copying data.
+ //
+ // This is not possible if atomic is set, because we can't allow the
+ // available buffer space to be consumed by some other caller while we
+ // are copying data in.
+ if !opts.Atomic {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
}
- // Copy in memory without holding sndBufMu so that worker goroutine can
- // make progress independent of this operation.
- v, perr := p.Get(avail)
- if perr != nil {
+ // Fetch data.
+ v, perr := p.Payload(avail)
+ if perr != nil || len(v) == 0 {
+ if opts.Atomic { // See above.
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ }
+ // Note that perr may be nil if len(v) == 0.
return 0, nil, perr
}
- e.mu.RLock()
- e.sndBufMu.Lock()
+ if !opts.Atomic { // See above.
+ e.mu.RLock()
+ e.sndBufMu.Lock()
- // Because we released the lock before copying, check state again
- // to make sure the endpoint is still in a valid state for a
- // write.
- avail, err = e.isEndpointWritableLocked()
- if err != nil {
- e.sndBufMu.Unlock()
- e.mu.RUnlock()
- return 0, nil, err
- }
+ // Because we released the lock before copying, check state again
+ // to make sure the endpoint is still in a valid state for a write.
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.mu.RUnlock()
+ return 0, nil, err
+ }
- // Discard any excess data copied in due to avail being reduced due to a
- // simultaneous write call to the socket.
- if avail < len(v) {
- v = v[:avail]
+ // Discard any excess data copied in due to avail being reduced due
+ // to a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
}
// Add data to the send queue.
- l := len(v)
s := newSegmentFromView(&e.route, e.id, v)
- e.sndBufUsed += l
- e.sndBufInQueue += seqnum.Size(l)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
e.sndQueue.PushBack(s)
-
e.sndBufMu.Unlock()
// Release the endpoint lock to prevent deadlocks due to lock
// order inversion when acquiring workMu.
@@ -875,7 +880,8 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return int64(l), nil, nil
+
+ return int64(len(v)), nil, nil
}
// Peek reads data without consuming it from the endpoint.
@@ -946,62 +952,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
-
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1065,6 +1018,67 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -1176,6 +1190,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1198,18 +1224,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.sndBufMu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
- e.rcvListMu.Unlock()
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 272bbcdbd..9fa97528b 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
enableCUBIC(t, c)
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 32bb45224..7fa5cfb6e 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
@@ -299,7 +299,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
want := stats.TCP.ResetsReceived.Value() + 1
iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -323,7 +323,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
want := stats.TCP.ResetsReceived.Value() + 1
iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -344,14 +344,14 @@ func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -367,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -417,7 +417,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -469,7 +469,7 @@ func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -557,8 +557,7 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
// Create a new connection with initial window size of 10.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -631,7 +630,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -700,7 +699,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -785,7 +784,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -804,8 +803,7 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -872,11 +870,9 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.Cleanup()
// Start off with a window size of 10, then shrink it to 5.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
- opt = 5
- if err := c.EP.SetSockOpt(opt); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
t.Fatalf("SetSockOpt failed: %v", err)
}
@@ -976,7 +972,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1017,7 +1013,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, nil)
+ c.CreateConnected(789, 0, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1075,8 +1071,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1110,8 +1105,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1151,7 +1145,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1224,7 +1218,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1293,8 +1287,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the window size such that a window scale of 4 will be used.
const wnd = 65535 * 10
const ws = uint32(4)
- opt := tcpip.ReceiveBufferSizeOption(wnd)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1399,7 +1392,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Prevent the endpoint from processing packets.
test.stop(c.EP)
@@ -1449,7 +1442,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1497,7 +1490,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1579,7 +1572,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -1695,7 +1688,7 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
@@ -1704,7 +1697,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -1727,7 +1720,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1871,7 +1864,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 2
- if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1973,7 +1966,7 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -2010,7 +2003,7 @@ func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -2035,7 +2028,7 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2078,7 +2071,7 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2132,7 +2125,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
@@ -2203,7 +2196,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -2291,7 +2284,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
@@ -2377,7 +2370,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -2509,7 +2502,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
@@ -2559,7 +2552,7 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
@@ -2580,7 +2573,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
vv := c.BuildSegment(nil, &context.Headers{
@@ -2604,7 +2597,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
@@ -2635,7 +2628,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
data := []byte{1, 2, 3}
@@ -2681,7 +2674,7 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -2856,8 +2849,8 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2869,8 +2862,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2945,26 +2938,26 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -3231,7 +3224,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -3308,7 +3301,7 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3482,7 +3475,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 18c707a57..78eff5c3a 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -150,11 +150,12 @@ func New(t *testing.T, mtu uint32) *Context {
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
- id, linkEP := channel.New(1000, mtu, "")
+ ep := channel.New(1000, mtu, "")
+ wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -180,7 +181,7 @@ func New(t *testing.T, mtu uint32) *Context {
return &Context{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
}
}
@@ -511,7 +512,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
}
// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
@@ -589,7 +590,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
//
// It also sets the receive buffer for the endpoint to the specified
// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -597,8 +598,8 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
index 4bec48c0f..43fcc27f0 100644
--- a/pkg/tcpip/transport/tcpconntrack/BUILD
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index ac2666f69..c1ca22b35 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "udp_packet_list",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 66455ef46..0bec7e62d 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,7 +15,6 @@
package udp
import (
- "math"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -277,17 +276,12 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress, netPr
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
- if p.Size() > math.MaxUint16 {
- // Payload can't possibly fit in a packet.
- return 0, nil, tcpip.ErrMessageTooLong
- }
-
to := opts.To
e.mu.RLock()
@@ -370,10 +364,14 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha
}
}
- v, err := p.Get(p.Size())
+ v, err := p.FullPayload()
if err != nil {
return 0, nil, err
}
+ if len(v) > header.UDPMaximumPacketSize {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
ttl := route.DefaultTTL()
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
@@ -391,7 +389,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
@@ -570,7 +573,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -580,18 +596,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -747,6 +751,10 @@ func (e *endpoint) Disconnect() *tcpip.Error {
}
e.state = StateBound
} else {
+ if e.id.LocalPort != 0 {
+ // Release the ephemeral port.
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ }
e.state = StateInitial
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 995d6e8a1..c6deab892 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -275,12 +275,13 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ ep := channel.New(256, mtu, "")
+ wep := stack.LinkEndpoint(ep)
- id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
- id = sniffer.New(id)
+ wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, id); err != nil {
+ if err := s.CreateNIC(1, wep); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -306,7 +307,7 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
return &testContext{
t: t,
s: s,
- linkEP: linkEP,
+ linkEP: ep,
}
}
diff --git a/pkg/tmutex/BUILD b/pkg/tmutex/BUILD
index 98d51cc69..6afdb29b7 100644
--- a/pkg/tmutex/BUILD
+++ b/pkg/tmutex/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/unet/BUILD b/pkg/unet/BUILD
index cbd92fc05..8f6f180e5 100644
--- a/pkg/unet/BUILD
+++ b/pkg/unet/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/urpc/BUILD b/pkg/urpc/BUILD
index b7f505a84..b6bbb0ea2 100644
--- a/pkg/urpc/BUILD
+++ b/pkg/urpc/BUILD
@@ -1,4 +1,5 @@
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
package(licenses = ["notice"])
diff --git a/pkg/waiter/BUILD b/pkg/waiter/BUILD
index 9173dfd0f..8dc88becb 100644
--- a/pkg/waiter/BUILD
+++ b/pkg/waiter/BUILD
@@ -1,7 +1,9 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
package(licenses = ["notice"])
load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
go_template_instance(
name = "waiter_list",
diff --git a/runsc/BUILD b/runsc/BUILD
index a2a465e1e..5e7dacb87 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -13,7 +13,7 @@ go_binary(
visibility = [
"//visibility:public",
],
- x_defs = {"main.version": "{VERSION}"},
+ x_defs = {"main.version": "{STABLE_VERSION}"},
deps = [
"//pkg/log",
"//pkg/refs",
@@ -46,7 +46,7 @@ go_binary(
visibility = [
"//visibility:public",
],
- x_defs = {"main.version": "{VERSION}"},
+ x_defs = {"main.version": "{STABLE_VERSION}"},
deps = [
"//pkg/log",
"//pkg/refs",
@@ -101,3 +101,10 @@ pkg_deb(
"//visibility:public",
],
)
+
+sh_test(
+ name = "version_test",
+ size = "small",
+ srcs = ["version_test.sh"],
+ data = [":runsc"],
+)
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 588bb8851..54d1ab129 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -109,6 +109,7 @@ go_test(
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 05b8f8761..31103367d 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -211,12 +211,6 @@ type Config struct {
// RestoreFile is the path to the saved container image
RestoreFile string
- // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
- // tests. It allows runsc to start the sandbox process as the current
- // user, and without chrooting the sandbox process. This can be
- // necessary in test environments that have limited capabilities.
- TestOnlyAllowRunAsCurrentUserWithoutChroot bool
-
// NumNetworkChannels controls the number of AF_PACKET sockets that map
// to the same underlying network device. This allows netstack to better
// scale for high throughput use cases.
@@ -233,6 +227,19 @@ type Config struct {
// ReferenceLeakMode sets reference leak check mode
ReferenceLeakMode refs.LeakMode
+
+ // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
+ // tests. It allows runsc to start the sandbox process as the current
+ // user, and without chrooting the sandbox process. This can be
+ // necessary in test environments that have limited capabilities.
+ TestOnlyAllowRunAsCurrentUserWithoutChroot bool
+
+ // TestOnlyTestNameEnv should only be used in tests. It looks up for the
+ // test name in the container environment variables and adds it to the debug
+ // log file name. This is done to help identify the log with the test when
+ // multiple tests are run in parallel, since there is no way to pass
+ // parameters to the runtime from docker.
+ TestOnlyTestNameEnv string
}
// ToFlags returns a slice of flags that correspond to the given Config.
@@ -261,9 +268,12 @@ func (c *Config) ToFlags() []string {
"--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr),
"--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode),
}
+ // Only include these if set since it is never to be used by users.
if c.TestOnlyAllowRunAsCurrentUserWithoutChroot {
- // Only include if set since it is never to be used by users.
- f = append(f, "-TESTONLY-unsafe-nonroot=true")
+ f = append(f, "--TESTONLY-unsafe-nonroot=true")
+ }
+ if len(c.TestOnlyTestNameEnv) != 0 {
+ f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv)
}
return f
}
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 7ca776b3a..a2ecc6bcb 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -88,14 +88,24 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
seccomp.AllowAny{},
seccomp.AllowAny{},
- seccomp.AllowValue(0),
},
{
seccomp.AllowAny{},
seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
seccomp.AllowAny{},
+ },
+ // Non-private variants are included for flipcall support. They are otherwise
+ // unncessary, as the sentry will use only private futexes internally.
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
seccomp.AllowAny{},
- seccomp.AllowValue(0),
},
},
syscall.SYS_GETPID: {},
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 823a34619..d824d7dc5 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -20,7 +20,6 @@ import (
mrand "math/rand"
"os"
"runtime"
- "strings"
"sync"
"sync/atomic"
"syscall"
@@ -535,23 +534,12 @@ func (l *Loader) run() error {
return err
}
- // Read /etc/passwd for the user's HOME directory and set the HOME
- // environment variable as required by POSIX if it is not overridden by
- // the user.
- hasHomeEnvv := false
- for _, envv := range l.rootProcArgs.Envv {
- if strings.HasPrefix(envv, "HOME=") {
- hasHomeEnvv = true
- }
- }
- if !hasHomeEnvv {
- homeDir, err := getExecUserHome(ctx, l.rootProcArgs.MountNamespace, uint32(l.rootProcArgs.Credentials.RealKUID))
- if err != nil {
- return fmt.Errorf("error reading exec user: %v", err)
- }
-
- l.rootProcArgs.Envv = append(l.rootProcArgs.Envv, "HOME="+homeDir)
+ // Add the HOME enviroment variable if it is not already set.
+ envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ if err != nil {
+ return err
}
+ l.rootProcArgs.Envv = envv
// Create the root container init task. It will begin running
// when the kernel is started.
@@ -815,6 +803,16 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
})
defer args.MountNamespace.DecRef()
+ // Add the HOME enviroment varible if it is not already set.
+ root := args.MountNamespace.Root()
+ defer root.DecRef()
+ ctx := fs.WithRoot(l.k.SupervisorContext(), root)
+ envv, err := maybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
+
// Start the process.
proc := control.Proc{Kernel: l.k}
args.PIDNamespace = tg.PIDNamespace()
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index ea0d9f790..32cba5ac1 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -121,10 +121,10 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
nicID++
nicids[link.Name] = nicID
- linkEP := loopback.New()
+ ep := loopback.New()
log.Infof("Enabling loopback interface %q with id %d on addresses %+v", link.Name, nicID, link.Addresses)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, true /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, true /* loopback */); err != nil {
return err
}
@@ -156,7 +156,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
mac := tcpip.LinkAddress(link.LinkAddress)
- linkEP, err := fdbased.New(&fdbased.Options{
+ ep, err := fdbased.New(&fdbased.Options{
FDs: FDs,
MTU: uint32(link.MTU),
EthernetHeader: true,
@@ -170,7 +170,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
}
log.Infof("Enabling interface %q with id %d on addresses %+v (%v) w/ %d channels", link.Name, nicID, link.Addresses, mac, link.NumChannels)
- if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses, false /* loopback */); err != nil {
+ if err := n.createNICWithAddrs(nicID, link.Name, ep, link.Addresses, false /* loopback */); err != nil {
return err
}
@@ -203,14 +203,14 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
// createNICWithAddrs creates a NIC in the network stack and adds the given
// addresses.
-func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, addrs []net.IP, loopback bool) error {
+func (n *Network) createNICWithAddrs(id tcpip.NICID, name string, ep stack.LinkEndpoint, addrs []net.IP, loopback bool) error {
if loopback {
- if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(linkEP)); err != nil {
- return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v, %v) failed: %v", id, name, linkEP, err)
+ if err := n.Stack.CreateNamedLoopbackNIC(id, name, sniffer.New(ep)); err != nil {
+ return fmt.Errorf("CreateNamedLoopbackNIC(%v, %v) failed: %v", id, name, err)
}
} else {
- if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(linkEP)); err != nil {
- return fmt.Errorf("CreateNamedNIC(%v, %v, %v) failed: %v", id, name, linkEP, err)
+ if err := n.Stack.CreateNamedNIC(id, name, sniffer.New(ep)); err != nil {
+ return fmt.Errorf("CreateNamedNIC(%v, %v) failed: %v", id, name, err)
}
}
diff --git a/runsc/boot/user.go b/runsc/boot/user.go
index d1d423a5c..56cc12ee0 100644
--- a/runsc/boot/user.go
+++ b/runsc/boot/user.go
@@ -16,6 +16,7 @@ package boot
import (
"bufio"
+ "fmt"
"io"
"strconv"
"strings"
@@ -23,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
@@ -42,7 +44,7 @@ func (r *fileReader) Read(buf []byte) (int, error) {
// getExecUserHome returns the home directory of the executing user read from
// /etc/passwd as read from the container filesystem.
-func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32) (string, error) {
+func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.KUID) (string, error) {
// The default user home directory to return if no user matching the user
// if found in the /etc/passwd found in the image.
const defaultHome = "/"
@@ -82,7 +84,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
File: f,
}
- homeDir, err := findHomeInPasswd(uid, r, defaultHome)
+ homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
if err != nil {
return "", err
}
@@ -90,6 +92,28 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
return homeDir, nil
}
+// maybeAddExecUserHome returns a new slice with the HOME enviroment variable
+// set if the slice does not already contain it, otherwise it returns the
+// original slice unmodified.
+func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHome(ctx, mns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+ return append(envv, "HOME="+homeDir), nil
+}
+
// findHomeInPasswd parses a passwd file and returns the given user's home
// directory. This function does it's best to replicate the runc's behavior.
func findHomeInPasswd(uid uint32, passwd io.Reader, defaultHome string) (string, error) {
diff --git a/runsc/boot/user_test.go b/runsc/boot/user_test.go
index 906baf3e5..9aee2ad07 100644
--- a/runsc/boot/user_test.go
+++ b/runsc/boot/user_test.go
@@ -25,6 +25,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
func setupTempDir() (string, error) {
@@ -68,7 +69,7 @@ func setupPasswd(contents string, perms os.FileMode) func() (string, error) {
// TestGetExecUserHome tests the getExecUserHome function.
func TestGetExecUserHome(t *testing.T) {
tests := map[string]struct {
- uid uint32
+ uid auth.KUID
createRoot func() (string, error)
expected string
}{
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index e817eff77..bf1225e1c 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -127,6 +127,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("getting environment variables: %v", err)
}
}
+
if e.Capabilities == nil {
// enableRaw is set to true to prevent the filtering out of
// CAP_NET_RAW. This is the opposite of Create() because exec
diff --git a/runsc/container/container.go b/runsc/container/container.go
index bbb364214..a721c1c31 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -513,9 +513,16 @@ func (c *Container) Start(conf *boot.Config) error {
return err
}
- // Adjust the oom_score_adj for sandbox and gofers. This must be done after
+ // Adjust the oom_score_adj for sandbox. This must be done after
// save().
- return c.adjustOOMScoreAdj(conf)
+ err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false)
+ if err != nil {
+ return err
+ }
+
+ // Set container's oom_score_adj to the gofer since it is dedicated to
+ // the container, in case the gofer uses up too much memory.
+ return c.adjustGoferOOMScoreAdj()
}
// Restore takes a container and replaces its kernel and file system
@@ -782,6 +789,9 @@ func (c *Container) Destroy() error {
}
defer unlock()
+ // Stored for later use as stop() sets c.Sandbox to nil.
+ sb := c.Sandbox
+
if err := c.stop(); err != nil {
err = fmt.Errorf("stopping container: %v", err)
log.Warningf("%v", err)
@@ -796,6 +806,16 @@ func (c *Container) Destroy() error {
c.changeStatus(Stopped)
+ // Adjust oom_score_adj for the sandbox. This must be done after the
+ // container is stopped and the directory at c.Root is removed.
+ // We must test if the sandbox is nil because Destroy should be
+ // idempotent.
+ if sb != nil {
+ if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil {
+ errs = append(errs, err.Error())
+ }
+ }
+
// "If any poststop hook fails, the runtime MUST log a warning, but the
// remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec.
// Based on the OCI, "The post-stop hooks MUST be called after the container is
@@ -926,7 +946,14 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund
}
if conf.DebugLog != "" {
- debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer")
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) != 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "gofer", test)
if err != nil {
return nil, nil, fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
}
@@ -1139,35 +1166,82 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
return fn()
}
-// adjustOOMScoreAdj sets the oom_score_adj for the sandbox and all gofers.
+// adjustGoferOOMScoreAdj sets the oom_store_adj for the container's gofer.
+func (c *Container) adjustGoferOOMScoreAdj() error {
+ if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil {
+ if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
+ return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
+ }
+ }
+
+ return nil
+}
+
+// adjustSandboxOOMScoreAdj sets the oom_score_adj for the sandbox.
// oom_score_adj is set to the lowest oom_score_adj among the containers
// running in the sandbox.
//
// TODO(gvisor.dev/issue/512): This call could race with other containers being
// created at the same time and end up setting the wrong oom_score_adj to the
// sandbox.
-func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error {
- // If this container's OOMScoreAdj is nil then we can exit early as no
- // change should be made to oom_score_adj for the sandbox.
- if c.Spec.Process.OOMScoreAdj == nil {
- return nil
- }
-
- containers, err := loadSandbox(conf.RootDir, c.Sandbox.ID)
+func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool) error {
+ containers, err := loadSandbox(rootDir, s.ID)
if err != nil {
return fmt.Errorf("loading sandbox containers: %v", err)
}
+ // Do nothing if the sandbox has been terminated.
+ if len(containers) == 0 {
+ return nil
+ }
+
// Get the lowest score for all containers.
var lowScore int
scoreFound := false
- for _, container := range containers {
- if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ if len(containers) == 1 && len(containers[0].Spec.Annotations[specutils.ContainerdContainerTypeAnnotation]) == 0 {
+ // This is a single-container sandbox. Set the oom_score_adj to
+ // the value specified in the OCI bundle.
+ if containers[0].Spec.Process.OOMScoreAdj != nil {
scoreFound = true
- lowScore = *container.Spec.Process.OOMScoreAdj
+ lowScore = *containers[0].Spec.Process.OOMScoreAdj
+ }
+ } else {
+ for _, container := range containers {
+ // Special multi-container support for CRI. Ignore the root
+ // container when calculating oom_score_adj for the sandbox because
+ // it is the infrastructure (pause) container and always has a very
+ // low oom_score_adj.
+ //
+ // We will use OOMScoreAdj in the single-container case where the
+ // containerd container-type annotation is not present.
+ if container.Spec.Annotations[specutils.ContainerdContainerTypeAnnotation] == specutils.ContainerdContainerTypeSandbox {
+ continue
+ }
+
+ if container.Spec.Process.OOMScoreAdj != nil && (!scoreFound || *container.Spec.Process.OOMScoreAdj < lowScore) {
+ scoreFound = true
+ lowScore = *container.Spec.Process.OOMScoreAdj
+ }
}
}
+ // If the container is destroyed and remaining containers have no
+ // oomScoreAdj specified then we must revert to the oom_score_adj of the
+ // parent process.
+ if !scoreFound && destroy {
+ ppid, err := specutils.GetParentPid(s.Pid)
+ if err != nil {
+ return fmt.Errorf("getting parent pid of sandbox pid %d: %v", s.Pid, err)
+ }
+ pScore, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ return fmt.Errorf("getting oom_score_adj of parent %d: %v", ppid, err)
+ }
+
+ scoreFound = true
+ lowScore = pScore
+ }
+
// Only set oom_score_adj if one of the containers has oom_score_adj set
// in the OCI bundle. If not, we need to inherit the parent process's
// oom_score_adj.
@@ -1177,15 +1251,10 @@ func (c *Container) adjustOOMScoreAdj(conf *boot.Config) error {
}
// Set the lowest of all containers oom_score_adj to the sandbox.
- if err := setOOMScoreAdj(c.Sandbox.Pid, lowScore); err != nil {
- return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", c.Sandbox.ID, err)
+ if err := setOOMScoreAdj(s.Pid, lowScore); err != nil {
+ return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
}
- // Set container's oom_score_adj to the gofer since it is dedicated to the
- // container, in case the gofer uses up too much memory.
- if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
- return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
- }
return nil
}
diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go
index 41f5fe1e8..e37ec0ffd 100644
--- a/runsc/dockerutil/dockerutil.go
+++ b/runsc/dockerutil/dockerutil.go
@@ -240,7 +240,7 @@ func (d *Docker) Stop() error {
// Run calls 'docker run' with the arguments provided. The container starts
// running in the background and the call returns immediately.
func (d *Docker) Run(args ...string) error {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-d"}
+ a := d.runArgs("-d")
a = append(a, args...)
_, err := do(a...)
if err == nil {
@@ -251,7 +251,7 @@ func (d *Docker) Run(args ...string) error {
// RunWithPty is like Run but with an attached pty.
func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-it"}
+ a := d.runArgs("-it")
a = append(a, args...)
return doWithPty(a...)
}
@@ -259,8 +259,7 @@ func (d *Docker) RunWithPty(args ...string) (*exec.Cmd, *os.File, error) {
// RunFg calls 'docker run' with the arguments provided in the foreground. It
// blocks until the container exits and returns the output.
func (d *Docker) RunFg(args ...string) (string, error) {
- a := []string{"run", "--runtime", d.Runtime, "--name", d.Name}
- a = append(a, args...)
+ a := d.runArgs(args...)
out, err := do(a...)
if err == nil {
d.logDockerID()
@@ -268,6 +267,14 @@ func (d *Docker) RunFg(args ...string) (string, error) {
return string(out), err
}
+func (d *Docker) runArgs(args ...string) []string {
+ // Environment variable RUNSC_TEST_NAME is picked up by the runtime and added
+ // to the log name, so one can easily identify the corresponding logs for
+ // this test.
+ rv := []string{"run", "--runtime", d.Runtime, "--name", d.Name, "-e", "RUNSC_TEST_NAME=" + d.Name}
+ return append(rv, args...)
+}
+
// Logs calls 'docker logs'.
func (d *Docker) Logs() (string, error) {
return do("logs", d.Name)
@@ -280,6 +287,14 @@ func (d *Docker) Exec(args ...string) (string, error) {
return do(a...)
}
+// ExecAsUser calls 'docker exec' as the given user with the arguments
+// provided.
+func (d *Docker) ExecAsUser(user string, args ...string) (string, error) {
+ a := []string{"exec", "--user", user, d.Name}
+ a = append(a, args...)
+ return do(a...)
+}
+
// ExecWithTerminal calls 'docker exec -it' with the arguments provided and
// attaches a pty to stdio.
func (d *Docker) ExecWithTerminal(args ...string) (*exec.Cmd, *os.File, error) {
diff --git a/runsc/fsgofer/filter/BUILD b/runsc/fsgofer/filter/BUILD
index e2318a978..02168ad1b 100644
--- a/runsc/fsgofer/filter/BUILD
+++ b/runsc/fsgofer/filter/BUILD
@@ -17,6 +17,7 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/flipcall",
"//pkg/log",
"//pkg/seccomp",
"@org_golang_x_sys//unix:go_default_library",
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 8ddfa77d6..2f3f2039a 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -83,6 +83,11 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowAny{},
seccomp.AllowValue(syscall.F_GETFD),
},
+ // Used by flipcall.PacketWindowAllocator.Init().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowValue(unix.F_ADD_SEALS),
+ },
},
syscall.SYS_FSTAT: {},
syscall.SYS_FSTATFS: {},
@@ -103,6 +108,19 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.AllowAny{},
seccomp.AllowValue(0),
},
+ // Non-private futex used for flipcall.
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAIT),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
+ seccomp.Rule{
+ seccomp.AllowAny{},
+ seccomp.AllowValue(linux.FUTEX_WAKE),
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ },
},
syscall.SYS_GETDENTS64: {},
syscall.SYS_GETPID: {},
@@ -112,6 +130,7 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_LINKAT: {},
syscall.SYS_LSEEK: {},
syscall.SYS_MADVISE: {},
+ unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init().
syscall.SYS_MKDIRAT: {},
syscall.SYS_MMAP: []seccomp.Rule{
{
@@ -160,6 +179,13 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_RT_SIGPROCMASK: {},
syscall.SYS_SCHED_YIELD: {},
syscall.SYS_SENDMSG: []seccomp.Rule{
+ // Used by fdchannel.Endpoint.SendFD().
+ {
+ seccomp.AllowAny{},
+ seccomp.AllowAny{},
+ seccomp.AllowValue(0),
+ },
+ // Used by unet.SocketWriter.WriteVec().
{
seccomp.AllowAny{},
seccomp.AllowAny{},
@@ -170,7 +196,15 @@ var allowedSyscalls = seccomp.SyscallRules{
{seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)},
},
syscall.SYS_SIGALTSTACK: {},
- syscall.SYS_SYMLINKAT: {},
+ // Used by fdchannel.NewConnectedSockets().
+ syscall.SYS_SOCKETPAIR: {
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC),
+ seccomp.AllowValue(0),
+ },
+ },
+ syscall.SYS_SYMLINKAT: {},
syscall.SYS_TGKILL: []seccomp.Rule{
{
seccomp.AllowValue(uint64(os.Getpid())),
diff --git a/runsc/main.go b/runsc/main.go
index b6546717c..304d771c2 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -79,6 +79,7 @@ var (
// Test flags, not to be used outside tests, ever.
testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.")
+ testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.")
)
func main() {
@@ -211,6 +212,7 @@ func main() {
ReferenceLeakMode: refsLeakMode,
TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot,
+ TestOnlyTestNameEnv: *testOnlyTestNameEnv,
}
if len(*straceSyscalls) != 0 {
conf.StraceSyscalls = strings.Split(*straceSyscalls, ",")
@@ -244,7 +246,7 @@ func main() {
e = newEmitter(*debugLogFormat, f)
} else if *debugLog != "" {
- f, err := specutils.DebugLogFile(*debugLog, subcommand)
+ f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */)
if err != nil {
cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err)
}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index df3c0c5ef..4c6c83fbd 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -351,7 +351,15 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
nextFD++
}
if conf.DebugLog != "" {
- debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot")
+ test := ""
+ if len(conf.TestOnlyTestNameEnv) == 0 {
+ // Fetch test name if one is provided and the test only flag was set.
+ if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
+ test = t
+ }
+ }
+
+ debugLogFile, err := specutils.DebugLogFile(conf.DebugLog, "boot", test)
if err != nil {
return fmt.Errorf("opening debug log file in %q: %v", conf.DebugLog, err)
}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index 2eec92349..cb9e58dfb 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -23,6 +23,7 @@ import (
"os"
"path"
"path/filepath"
+ "strconv"
"strings"
"syscall"
"time"
@@ -398,13 +399,15 @@ func WaitForReady(pid int, timeout time.Duration, ready func() (bool, error)) er
// - %TIMESTAMP%: is replaced with a timestamp using the following format:
// <yyyymmdd-hhmmss.uuuuuu>
// - %COMMAND%: is replaced with 'command'
-func DebugLogFile(logPattern, command string) (*os.File, error) {
+// - %TEST%: is replaced with 'test' (omitted by default)
+func DebugLogFile(logPattern, command, test string) (*os.File, error) {
if strings.HasSuffix(logPattern, "/") {
// Default format: <debug-log>/runsc.log.<yyyymmdd-hhmmss.uuuuuu>.<command>
logPattern += "runsc.log.%TIMESTAMP%.%COMMAND%"
}
logPattern = strings.Replace(logPattern, "%TIMESTAMP%", time.Now().Format("20060102-150405.000000"), -1)
logPattern = strings.Replace(logPattern, "%COMMAND%", command, -1)
+ logPattern = strings.Replace(logPattern, "%TEST%", test, -1)
dir := filepath.Dir(logPattern)
if err := os.MkdirAll(dir, 0775); err != nil {
@@ -503,3 +506,53 @@ func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) {
}
}
}
+
+// GetOOMScoreAdj reads the given process' oom_score_adj
+func GetOOMScoreAdj(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/oom_score_adj", pid))
+ if err != nil {
+ return 0, err
+ }
+ return strconv.Atoi(strings.TrimSpace(string(data)))
+}
+
+// GetParentPid gets the parent process ID of the specified PID.
+func GetParentPid(pid int) (int, error) {
+ data, err := ioutil.ReadFile(fmt.Sprintf("/proc/%d/stat", pid))
+ if err != nil {
+ return 0, err
+ }
+
+ var cpid string
+ var name string
+ var state string
+ var ppid int
+ // Parse after the binary name.
+ _, err = fmt.Sscanf(string(data),
+ "%v %v %v %d",
+ // cpid is ignored.
+ &cpid,
+ // name is ignored.
+ &name,
+ // state is ignored.
+ &state,
+ &ppid)
+
+ if err != nil {
+ return 0, err
+ }
+
+ return ppid, nil
+}
+
+// EnvVar looks for a varible value in the env slice assuming the following
+// format: "NAME=VALUE".
+func EnvVar(env []string, name string) (string, bool) {
+ prefix := name + "="
+ for _, e := range env {
+ if strings.HasPrefix(e, prefix) {
+ return strings.TrimPrefix(e, prefix), true
+ }
+ }
+ return "", false
+}
diff --git a/runsc/testutil/testutil.go b/runsc/testutil/testutil.go
index 57ab73d97..edf8b126c 100644
--- a/runsc/testutil/testutil.go
+++ b/runsc/testutil/testutil.go
@@ -26,12 +26,14 @@ import (
"io"
"io/ioutil"
"log"
+ "math"
"math/rand"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
+ "strconv"
"strings"
"sync"
"sync/atomic"
@@ -438,3 +440,44 @@ func IsStatic(filename string) (bool, error) {
}
return true, nil
}
+
+// TestBoundsForShard calculates the beginning and end indices for the test
+// based on the TEST_SHARD_INDEX and TEST_TOTAL_SHARDS environment vars. The
+// returned ints are the beginning (inclusive) and end (exclusive) of the
+// subslice corresponding to the shard. If either of the env vars are not
+// present, then the function will return bounds that include all tests. If
+// there are more shards than there are tests, then the returned list may be
+// empty.
+func TestBoundsForShard(numTests int) (int, int, error) {
+ var (
+ begin = 0
+ end = numTests
+ )
+ indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS")
+ if indexStr == "" || totalStr == "" {
+ return begin, end, nil
+ }
+
+ // Parse index and total to ints.
+ shardIndex, err := strconv.Atoi(indexStr)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
+ }
+ shardTotal, err := strconv.Atoi(totalStr)
+ if err != nil {
+ return 0, 0, fmt.Errorf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
+ }
+
+ // Calculate!
+ shardSize := int(math.Ceil(float64(numTests) / float64(shardTotal)))
+ begin = shardIndex * shardSize
+ end = ((shardIndex + 1) * shardSize)
+ if begin > numTests {
+ // Nothing to run.
+ return 0, 0, nil
+ }
+ if end > numTests {
+ end = numTests
+ }
+ return begin, end, nil
+}
diff --git a/runsc/version.go b/runsc/version.go
index ce0573a9b..ab9194b9d 100644
--- a/runsc/version.go
+++ b/runsc/version.go
@@ -15,4 +15,4 @@
package main
// version is set during linking.
-var version = ""
+var version = "VERSION_MISSING"
diff --git a/runsc/version_test.sh b/runsc/version_test.sh
new file mode 100755
index 000000000..cc0ca3f05
--- /dev/null
+++ b/runsc/version_test.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+# Copyright 2018 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -euf -x -o pipefail
+
+readonly runsc="${TEST_SRCDIR}/__main__/runsc/linux_amd64_pure_stripped/runsc"
+readonly version=$($runsc --version)
+
+# Version should should not match VERSION, which is the default and which will
+# also appear if something is wrong with workspace_status.sh script.
+if [[ $version =~ "VERSION" ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+# Version should contain at least one number.
+if [[ ! $version =~ [0-9] ]]; then
+ echo "FAIL: Got bad version $version"
+ exit 1
+fi
+
+echo "PASS: Got OK version $version"
+exit 0
diff --git a/scripts/build.sh b/scripts/build.sh
index 293d87093..b3a6e4e7a 100755
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -16,6 +16,9 @@
source $(dirname $0)/common.sh
+# Install required packages for make_repository.sh et al.
+sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils
+
# Build runsc.
runsc=$(build -c opt //runsc)
@@ -24,16 +27,19 @@ pkg=$(build -c opt --host_force_python=py2 //runsc:runsc-debian)
# Build a repository, if the key is available.
if [[ -v KOKORO_REPO_KEY ]]; then
- repo=$(tools/make_repository.sh "${KOKORO_REPO_KEY}" gvisor-bot@google.com)
+ repo=$(tools/make_repository.sh "${KOKORO_KEYSTORE_DIR}/${KOKORO_REPO_KEY}" gvisor-bot@google.com main ${pkg})
fi
# Install installs artifacts.
install() {
- mkdir -p $1
- cp "${runsc}" "$1"/runsc
- sha512sum "$1"/runsc | awk '{print $1 " runsc"}' > "$1"/runsc.sha512
+ local -r binaries_dir="$1"
+ local -r repo_dir="$2"
+ mkdir -p "${binaries_dir}"
+ cp -f "${runsc}" "${binaries_dir}"/runsc
+ sha512sum "${binaries_dir}"/runsc | awk '{print $1 " runsc"}' > "${binaries_dir}"/runsc.sha512
if [[ -v repo ]]; then
- cp -a "${repo}" "${latest_dir}"/repo
+ rm -rf "${repo_dir}" && mkdir -p "$(dirname "${repo_dir}")"
+ cp -a "${repo}" "${repo_dir}"
fi
}
@@ -41,22 +47,33 @@ install() {
# current date. If the current commit happens to correpond to a tag, then we
# will also move everything into a directory named after the given tag.
if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
- if [[ "${KOKORO_BUILD_NIGHTLY}" == "true" ]]; then
+ if [[ "${KOKORO_BUILD_NIGHTLY:-false}" == "true" ]]; then
# The "latest" directory and current date.
- install "${KOKORO_ARTIFACTS_DIR}/nightly/latest"
- install "${KOKORO_ARTIFACTS_DIR}/nightly/$(date -Idate)"
+ stamp="$(date -Idate)"
+ install "${KOKORO_ARTIFACTS_DIR}/nightly/latest" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/nightly/latest"
+ install "${KOKORO_ARTIFACTS_DIR}/nightly/${stamp}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/nightly/${stamp}"
else
# Is it a tagged release? Build that instead. In that case, we also try to
# update the base release directory, in case this is an update. Finally, we
# update the "release" directory, which has the last released version.
- tag="$(git describe --exact-match --tags HEAD)"
- if ! [[ -z "${tag}" ]]; then
- install "${KOKORO_ARTIFACTS_DIR}/${tag}"
- base=$(echo "${tag}" | cut -d'.' -f1)
- if [[ "${base}" != "${tag}" ]]; then
- install "${KOKORO_ARTIFACTS_DIR}/${base}"
- fi
- install "${KOKORO_ARTIFACTS_DIR}/release"
+ tags="$(git tag --points-at HEAD)"
+ if ! [[ -z "${tags}" ]]; then
+ # Note that a given commit can match any number of tags. We have to
+ # iterate through all possible tags and produce associated artifacts.
+ for tag in ${tags}; do
+ name=$(echo "${tag}" | cut -d'-' -f2)
+ base=$(echo "${name}" | cut -d'.' -f1)
+ install "${KOKORO_ARTIFACTS_DIR}/release/${name}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/${name}"
+ if [[ "${base}" != "${tag}" ]]; then
+ install "${KOKORO_ARTIFACTS_DIR}/release/${base}" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/${base}"
+ fi
+ install "${KOKORO_ARTIFACTS_DIR}/release/latest" \
+ "${KOKORO_ARTIFACTS_DIR}/dists/latest"
+ done
fi
fi
fi
diff --git a/scripts/common.sh b/scripts/common.sh
index f2b9e24d8..6dabad141 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -14,10 +14,67 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -xeo pipefail
+set -xeou pipefail
if [[ -f $(dirname $0)/common_google.sh ]]; then
source $(dirname $0)/common_google.sh
else
source $(dirname $0)/common_bazel.sh
fi
+
+# Ensure it attempts to collect logs in all cases.
+trap collect_logs EXIT
+
+function set_runtime() {
+ RUNTIME=${1:-runsc}
+ RUNSC_BIN=/tmp/"${RUNTIME}"/runsc
+ RUNSC_LOGS_DIR="$(dirname ${RUNSC_BIN})"/logs
+ RUNSC_LOGS="${RUNSC_LOGS_DIR}"/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%
+}
+
+function test_runsc() {
+ test --test_arg=--runtime=${RUNTIME} "$@"
+}
+
+function install_runsc_for_test() {
+ local -r test_name=$1
+ shift
+ if [[ -z "${test_name}" ]]; then
+ echo "Missing mandatory test name"
+ exit 1
+ fi
+
+ # Add test to the name, so it doesn't conflict with other runtimes.
+ set_runtime $(find_branch_name)_"${test_name}"
+
+ # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name
+ # down to the runtime.
+ install_runsc "${RUNTIME}" \
+ --TESTONLY-test-name-env=RUNSC_TEST_NAME \
+ --debug \
+ --strace \
+ --log-packets \
+ "$@"
+}
+
+# Installs the runsc with given runtime name. set_runtime must have been called
+# to set runtime and logs location.
+function install_runsc() {
+ local -r runtime=$1
+ shift
+
+ # Prepare the runtime binary.
+ local -r output=$(build //runsc)
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f "${output}" "${RUNSC_BIN}"
+ chmod 0755 "${RUNSC_BIN}"
+
+ # Install the runtime.
+ sudo "${RUNSC_BIN}" install --experimental=true --runtime="${runtime}" -- --debug-log "${RUNSC_LOGS}" "$@"
+
+ # Clear old logs files that may exist.
+ sudo rm -f "${RUNSC_LOGS_DIR}"/*
+
+ # Restart docker to pick up the new runtime configuration.
+ sudo systemctl restart docker
+}
diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh
index 42248cb25..dde0b51ed 100755
--- a/scripts/common_bazel.sh
+++ b/scripts/common_bazel.sh
@@ -48,20 +48,12 @@ fi
# Wrap bazel.
function build() {
- bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@"
+ bazel build "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 |
+ tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }'
}
function test() {
- (bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@" && rc=0) || rc=$?
-
- # Zip out everything into a convenient form.
- if [[ -v KOKORO_ARTIFACTS_DIR ]]; then
- find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
- tar --create --files-from - --transform 's/test\./sponge_log./' |
- tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
- fi
-
- return $rc
+ bazel test "${BAZEL_RBE_FLAGS[@]}" "${BAZEL_RBE_AUTH_FLAGS[@]}" "${BAZEL_FLAGS[@]}" "$@"
}
function run() {
@@ -75,3 +67,26 @@ function run_as_root() {
shift
bazel run --run_under="sudo" "${binary}" -- "$@"
}
+
+function collect_logs() {
+ # Zip out everything into a convenient form.
+ if [[ -v KOKORO_ARTIFACTS_DIR ]] && [[ -e bazel-testlogs ]]; then
+ # Move test logs to Kokoro directory. tar is used to conveniently perform
+ # renames while moving files.
+ find -L "bazel-testlogs" -name "test.xml" -o -name "test.log" -o -name "outputs.zip" |
+ tar --create --files-from - --transform 's/test\./sponge_log./' |
+ tar --extract --directory ${KOKORO_ARTIFACTS_DIR}
+
+ # Collect sentry logs, if any.
+ if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then
+ local -r logs=$(ls "${RUNSC_LOGS_DIR}")
+ if [[ -z "${logs}" ]]; then
+ tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${RUNTIME}.tar.gz" -C "${RUNSC_LOGS_DIR}" .
+ fi
+ fi
+ fi
+}
+
+function find_branch_name() {
+ git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename
+}
diff --git a/scripts/dev.sh b/scripts/dev.sh
new file mode 100755
index 000000000..ee74dcb72
--- /dev/null
+++ b/scripts/dev.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+# common.sh sets '-x', but it's annoying to see so much output.
+set +x
+
+# Defaults
+declare -i REFRESH=0
+declare NAME=$(find_branch_name)
+
+while [[ $# -gt 0 ]]; do
+ case "$1" in
+ --refresh)
+ REFRESH=1
+ ;;
+ --help)
+ echo "Use this script to build and install runsc with Docker."
+ echo
+ echo "usage: $0 [--refresh] [runtime_name]"
+ exit 1
+ ;;
+ *)
+ NAME=$1
+ ;;
+ esac
+ shift
+done
+
+set_runtime "${NAME}"
+echo
+echo "Using runtime=${RUNTIME}"
+echo
+
+echo Building runsc...
+# Build first and fail on error. $() prevents "set -e" from reporting errors.
+build //runsc
+declare OUTPUT="$(build //runsc)"
+
+if [[ ${REFRESH} -eq 0 ]]; then
+ install_runsc "${RUNTIME}" --net-raw
+ install_runsc "${RUNTIME}-d" --net-raw --debug --strace --log-packets
+
+ echo
+ echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup."
+ echo "Use --runtime="${RUNTIME}" with your Docker command."
+ echo " docker run --rm --runtime="${RUNTIME}" --rm hello-world"
+ echo
+ echo "If you rebuild, use $0 --refresh."
+
+else
+ mkdir -p "$(dirname ${RUNSC_BIN})"
+ cp -f ${OUTPUT} "${RUNSC_BIN}"
+
+ echo
+ echo "Runtime ${RUNTIME} refreshed."
+fi
+
+echo "Logs are in: ${RUNSC_LOGS_DIR}"
diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh
index d6b18a35b..72ba05260 100755
--- a/scripts/docker_tests.sh
+++ b/scripts/docker_tests.sh
@@ -16,7 +16,5 @@
source $(dirname $0)/common.sh
-# Install the runtime and perform basic tests.
-run_as_root //runsc install --experimental=true -- --debug --strace --log-packets
-sudo systemctl restart docker
-test //test/image:image_test //test/e2e:integration_test
+install_runsc_for_test docker
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/go.sh b/scripts/go.sh
index e49d76c6d..0dbfb7747 100755
--- a/scripts/go.sh
+++ b/scripts/go.sh
@@ -29,6 +29,15 @@ git checkout go && git clean -f
go build ./...
# Push, if required.
-if [[ "${KOKORO_GO_PUSH}" == "true" ]]; then
+if [[ -v KOKORO_GO_PUSH ]] && [[ "${KOKORO_GO_PUSH}" == "true" ]]; then
+ if [[ -v KOKORO_GITHUB_ACCESS_TOKEN ]]; then
+ git config --global credential.helper cache
+ git credential approve <<EOF
+protocol=https
+host=github.com
+username=$(cat "${KOKORO_KEYSTORE_DIR}/${KOKORO_GITHUB_ACCESS_TOKEN}")
+password=x-oauth-basic
+EOF
+ fi
git push origin go:go
fi
diff --git a/scripts/hostnet_tests.sh b/scripts/hostnet_tests.sh
index 0631c5510..41298293d 100755
--- a/scripts/hostnet_tests.sh
+++ b/scripts/hostnet_tests.sh
@@ -17,6 +17,5 @@
source $(dirname $0)/common.sh
# Install the runtime and perform basic tests.
-run_as_root //runsc install --experimental=true -- --debug --strace --log-packets --network=host
-sudo systemctl restart docker
-test --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test
+install_runsc_for_test hostnet --network=host
+test_runsc --test_arg=-checkpoint=false //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/kvm_tests.sh b/scripts/kvm_tests.sh
index 5cb7aa007..5662401df 100755
--- a/scripts/kvm_tests.sh
+++ b/scripts/kvm_tests.sh
@@ -20,11 +20,9 @@ source $(dirname $0)/common.sh
(lsmod | grep -E '^(kvm_intel|kvm_amd)') || sudo modprobe kvm
sudo chmod a+rw /dev/kvm
-# Run all KVM-tagged tests (locally).
-test --test_strategy=standalone --test_tag_filters=requires-kvm //...
-test --test_strategy=standalone //pkg/sentry/platform/kvm:kvm_test
+# Run all KVM platform tests (locally).
+run_as_root //pkg/sentry/platform/kvm:kvm_test
# Install the KVM runtime and run all integration tests.
-run_as_root //runsc install --experimental=true -- --debug --strace --log-packets --platform=kvm
-sudo systemctl restart docker
-test --test_strategy=standalone //test/image:image_test //test/e2e:integration_test
+install_runsc_for_test kvm --platform=kvm
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/overlay_tests.sh b/scripts/overlay_tests.sh
index 651a51f70..2a1f12c0b 100755
--- a/scripts/overlay_tests.sh
+++ b/scripts/overlay_tests.sh
@@ -17,6 +17,5 @@
source $(dirname $0)/common.sh
# Install the runtime and perform basic tests.
-run_as_root //runsc install --experimental=true -- --debug --strace --log-packets --overlay
-sudo systemctl restart docker
-test //test/image:image_test //test/e2e:integration_test
+install_runsc_for_test overlay --overlay
+test_runsc //test/image:image_test //test/e2e:integration_test
diff --git a/scripts/release.sh b/scripts/release.sh
index 422319500..b936bcc77 100755
--- a/scripts/release.sh
+++ b/scripts/release.sh
@@ -26,9 +26,13 @@ if ! [[ -v KOKORO_RELEASE_TAG ]]; then
exit 1
fi
+# Unless an explicit releaser is provided, use the bot e-mail.
+declare -r KOKORO_RELEASE_AUTHOR=${KOKORO_RELEASE_AUTHOR:-gvisor-bot}
+declare -r EMAIL=${EMAIL:-${KOKORO_RELEASE_AUTHOR}@google.com}
+
# Ensure we have an appropriate configuration for the tag.
git config --get user.name || git config user.name "gVisor-bot"
-git config --get user.email || git config user.email "gvisor-bot@google.com"
+git config --get user.email || git config user.email "${EMAIL}"
# Run the release tool, which pushes to the origin repository.
tools/tag_release.sh "${KOKORO_RELEASE_COMMIT}" "${KOKORO_RELEASE_TAG}"
diff --git a/scripts/root_tests.sh b/scripts/root_tests.sh
index e42c0e3ec..4e4fcc76b 100755
--- a/scripts/root_tests.sh
+++ b/scripts/root_tests.sh
@@ -26,6 +26,6 @@ chmod +x ${shim_path}
sudo mv ${shim_path} /usr/local/bin/gvisor-containerd-shim
# Run the tests that require root.
-run_as_root //runsc install --experimental=true -- --debug --strace --log-packets
-sudo systemctl restart docker
-run_as_root //test/root:root_test
+install_runsc_for_test root
+run_as_root //test/root:root_test --runtime=${RUNTIME}
+
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
index ce2c4f689..7238c2afe 100644
--- a/test/e2e/exec_test.go
+++ b/test/e2e/exec_test.go
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package image provides end-to-end integration tests for runsc. These tests
-// require docker and runsc to be installed on the machine.
+// Package integration provides end-to-end integration tests for runsc. These
+// tests require docker and runsc to be installed on the machine.
//
// Each test calls docker commands to start up a container, and tests that it
// is behaving properly, with various runsc commands. The container is killed
@@ -154,3 +154,68 @@ func TestExecError(t *testing.T) {
t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want)
}
}
+
+// Test that exec inherits environment from run.
+func TestExecEnv(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-env-test")
+
+ // Start the container with env FOO=BAR.
+ if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec "echo $FOO".
+ got, err := d.Exec("/bin/sh", "-c", "echo $FOO")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "BAR"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
+
+// Test that exec always has HOME environment set, even when not set in run.
+func TestExecEnvHasHome(t *testing.T) {
+ // Base alpine image does not have any environment variables set.
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-env-test")
+
+ // We will check that HOME is set for root user, and also for a new
+ // non-root user we will create.
+ newUID := 1234
+ newHome := "/foo/bar"
+
+ // Create a new user with a home directory, and then sleep.
+ script := fmt.Sprintf(`
+ mkdir -p -m 777 %s && \
+ adduser foo -D -u %d -h %s && \
+ sleep 1000`, newHome, newUID, newHome)
+ if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec "echo $HOME", and expect to see "/root".
+ got, err := d.Exec("/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "/root"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+
+ // Execute the same as uid 123 and expect newHome.
+ got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := newHome; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
diff --git a/test/root/BUILD b/test/root/BUILD
index f130df2c7..d5dd9bca2 100644
--- a/test/root/BUILD
+++ b/test/root/BUILD
@@ -15,6 +15,11 @@ go_test(
"cgroup_test.go",
"chroot_test.go",
"crictl_test.go",
+ "main_test.go",
+ "oom_score_adj_test.go",
+ ],
+ data = [
+ "//runsc",
],
embed = [":root"],
tags = [
@@ -25,12 +30,15 @@ go_test(
],
visibility = ["//:sandbox"],
deps = [
+ "//runsc/boot",
"//runsc/cgroup",
+ "//runsc/container",
"//runsc/criutil",
"//runsc/dockerutil",
"//runsc/specutils",
"//runsc/testutil",
"//test/root/testdata",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
"@com_github_syndtr_gocapability//capability:go_default_library",
],
)
diff --git a/test/root/cgroup_test.go b/test/root/cgroup_test.go
index cc7e8583e..76f1e4f2a 100644
--- a/test/root/cgroup_test.go
+++ b/test/root/cgroup_test.go
@@ -62,6 +62,12 @@ func TestCgroup(t *testing.T) {
}
d := dockerutil.MakeDocker("cgroup-test")
+ // This is not a comprehensive list of attributes.
+ //
+ // Note that we are specifically missing cpusets, which fail if specified.
+ // In any case, it's unclear if cpusets can be reliably tested here: these
+ // are often run on a single core virtual machine, and there is only a single
+ // CPU available in our current set, and every container's set.
attrs := []struct {
arg string
ctrl string
@@ -88,18 +94,6 @@ func TestCgroup(t *testing.T) {
want: "3000",
},
{
- arg: "--cpuset-cpus=0",
- ctrl: "cpuset",
- file: "cpuset.cpus",
- want: "0",
- },
- {
- arg: "--cpuset-mems=0",
- ctrl: "cpuset",
- file: "cpuset.mems",
- want: "0",
- },
- {
arg: "--kernel-memory=100MB",
ctrl: "memory",
file: "memory.kmem.limit_in_bytes",
diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go
index f47f8e2c2..be0f63d18 100644
--- a/test/root/chroot_test.go
+++ b/test/root/chroot_test.go
@@ -16,19 +16,15 @@
package root
import (
- "flag"
"fmt"
"io/ioutil"
- "os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"testing"
- "github.com/syndtr/gocapability/capability"
"gvisor.dev/gvisor/runsc/dockerutil"
- "gvisor.dev/gvisor/runsc/specutils"
)
// TestChroot verifies that the sandbox is chroot'd and that mounts are cleaned
@@ -144,15 +140,3 @@ func TestChrootGofer(t *testing.T) {
}
}
}
-
-func TestMain(m *testing.M) {
- dockerutil.EnsureSupportedDockerVersion()
-
- if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) {
- fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.")
- os.Exit(1)
- }
-
- flag.Parse()
- os.Exit(m.Run())
-}
diff --git a/test/root/main_test.go b/test/root/main_test.go
new file mode 100644
index 000000000..d74dec85f
--- /dev/null
+++ b/test/root/main_test.go
@@ -0,0 +1,49 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "testing"
+
+ "github.com/syndtr/gocapability/capability"
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// TestMain is the main function for root tests. This function checks the
+// supported docker version, required capabilities, and configures the executable
+// path for runsc.
+func TestMain(m *testing.M) {
+ flag.Parse()
+
+ if !specutils.HasCapabilities(capability.CAP_SYS_ADMIN, capability.CAP_DAC_OVERRIDE) {
+ fmt.Println("Test requires sysadmin privileges to run. Try again with sudo.")
+ os.Exit(1)
+ }
+
+ dockerutil.EnsureSupportedDockerVersion()
+
+ // Configure exe for tests.
+ path, err := dockerutil.RuntimePath()
+ if err != nil {
+ panic(err.Error())
+ }
+ specutils.ExePath = path
+
+ os.Exit(m.Run())
+}
diff --git a/test/root/oom_score_adj_test.go b/test/root/oom_score_adj_test.go
new file mode 100644
index 000000000..6cd378a1b
--- /dev/null
+++ b/test/root/oom_score_adj_test.go
@@ -0,0 +1,376 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package root
+
+import (
+ "fmt"
+ "os"
+ "testing"
+
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "gvisor.dev/gvisor/runsc/boot"
+ "gvisor.dev/gvisor/runsc/container"
+ "gvisor.dev/gvisor/runsc/specutils"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+var (
+ maxOOMScoreAdj = 1000
+ highOOMScoreAdj = 500
+ lowOOMScoreAdj = -500
+ minOOMScoreAdj = -1000
+)
+
+// Tests for oom_score_adj have to be run as root (rather than in a user
+// namespace) because we need to adjust oom_score_adj for PIDs other than our
+// own and test values below 0.
+
+// TestOOMScoreAdjSingle tests that oom_score_adj is set properly in a
+// single container sandbox.
+func TestOOMScoreAdjSingle(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set.
+ OOMScoreAdj *int
+ }{
+ {
+ Name: "max",
+ OOMScoreAdj: &maxOOMScoreAdj,
+ },
+ {
+ Name: "high",
+ OOMScoreAdj: &highOOMScoreAdj,
+ },
+ {
+ Name: "low",
+ OOMScoreAdj: &lowOOMScoreAdj,
+ },
+ {
+ Name: "min",
+ OOMScoreAdj: &minOOMScoreAdj,
+ },
+ {
+ Name: "nil",
+ OOMScoreAdj: &parentOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ id := testutil.UniqueContainerID()
+ s := testutil.NewSpecWithArgs("sleep", "1000")
+ s.Process.OOMScoreAdj = testCase.OOMScoreAdj
+
+ conf := testutil.TestConfig()
+ containers, cleanup, err := startContainers(conf, []*specs.Spec{s}, []string{id})
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ c := containers[0]
+
+ // Verify the gofer's oom_score_adj
+ if testCase.OOMScoreAdj != nil {
+ goferScore, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if goferScore != *testCase.OOMScoreAdj {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", goferScore, *testCase.OOMScoreAdj)
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := c.Sandbox.Pid
+ sandboxScore, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if sandboxScore != *testCase.OOMScoreAdj {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", sandboxScore, *testCase.OOMScoreAdj)
+ }
+ }
+ })
+ }
+}
+
+// TestOOMScoreAdjMulti tests that oom_score_adj is set properly in a
+// multi-container sandbox.
+func TestOOMScoreAdjMulti(t *testing.T) {
+ ppid, err := specutils.GetParentPid(os.Getpid())
+ if err != nil {
+ t.Fatalf("getting parent pid: %v", err)
+ }
+ parentOOMScoreAdj, err := specutils.GetOOMScoreAdj(ppid)
+ if err != nil {
+ t.Fatalf("getting parent oom_score_adj: %v", err)
+ }
+
+ testCases := []struct {
+ Name string
+
+ // OOMScoreAdj is the oom_score_adj set to the OCI spec. If nil then
+ // no value is set. One value for each container. The first value is the
+ // root container.
+ OOMScoreAdj []*int
+
+ // Expected is the expected oom_score_adj of the sandbox. If nil, then
+ // this value is ignored.
+ Expected *int
+
+ // Remove is a set of container indexes to remove from the sandbox.
+ Remove []int
+
+ // ExpectedAfterRemove is the expected oom_score_adj of the sandbox
+ // after containers are removed. Ignored if nil.
+ ExpectedAfterRemove *int
+ }{
+ // A single container CRI test case. This should not happen in
+ // practice as there should be at least one container besides the pause
+ // container. However, we include a test case to ensure sane behavior.
+ {
+ Name: "single",
+ OOMScoreAdj: []*int{&highOOMScoreAdj},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_no_value",
+ OOMScoreAdj: []*int{nil, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_non_nil_root",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, nil},
+ Expected: &parentOOMScoreAdj,
+ },
+ {
+ Name: "multi_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &highOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_min_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &lowOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &lowOOMScoreAdj,
+ },
+ {
+ Name: "multi_max_value",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ },
+ {
+ Name: "remove_adjusted",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove highOOMScoreAdj container.
+ Remove: []int{2},
+ ExpectedAfterRemove: &maxOOMScoreAdj,
+ },
+ {
+ // This test removes all non-root sandboxes with a specified oomScoreAdj.
+ Name: "remove_to_nil",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, nil, &lowOOMScoreAdj},
+ Expected: &lowOOMScoreAdj,
+ // Remove lowOOMScoreAdj container.
+ Remove: []int{2},
+ // The oom_score_adj expected after remove is that of the parent process.
+ ExpectedAfterRemove: &parentOOMScoreAdj,
+ },
+ {
+ Name: "remove_no_effect",
+ OOMScoreAdj: []*int{&minOOMScoreAdj, &maxOOMScoreAdj, &highOOMScoreAdj},
+ // The lowest value excluding the root container is expected.
+ Expected: &highOOMScoreAdj,
+ // Remove the maxOOMScoreAdj container.
+ Remove: []int{1},
+ ExpectedAfterRemove: &highOOMScoreAdj,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ var cmds [][]string
+ var oomScoreAdj []*int
+ var toRemove []string
+
+ for _, oomScore := range testCase.OOMScoreAdj {
+ oomScoreAdj = append(oomScoreAdj, oomScore)
+ cmds = append(cmds, []string{"sleep", "100"})
+ }
+
+ specs, ids := createSpecs(cmds...)
+ for i, spec := range specs {
+ // Ensure the correct value is set, including no value.
+ spec.Process.OOMScoreAdj = oomScoreAdj[i]
+
+ for _, j := range testCase.Remove {
+ if i == j {
+ toRemove = append(toRemove, ids[i])
+ }
+ }
+ }
+
+ conf := testutil.TestConfig()
+ containers, cleanup, err := startContainers(conf, specs, ids)
+ if err != nil {
+ t.Fatalf("error starting containers: %v", err)
+ }
+ defer cleanup()
+
+ for i, c := range containers {
+ if oomScoreAdj[i] != nil {
+ // Verify the gofer's oom_score_adj
+ score, err := specutils.GetOOMScoreAdj(c.GoferPid)
+ if err != nil {
+ t.Fatalf("error reading gofer oom_score_adj: %v", err)
+ }
+ if score != *oomScoreAdj[i] {
+ t.Errorf("gofer oom_score_adj got: %d, want: %d", score, *oomScoreAdj[i])
+ }
+ }
+ }
+
+ // Verify the sandbox's oom_score_adj.
+ //
+ // The sandbox should be the same for all containers so just use
+ // the first one.
+ sandboxPid := containers[0].Sandbox.Pid
+ if testCase.Expected != nil {
+ score, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if score != *testCase.Expected {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", score, *testCase.Expected)
+ }
+ }
+
+ if len(toRemove) == 0 {
+ return
+ }
+
+ // Remove containers.
+ for _, removeID := range toRemove {
+ for _, c := range containers {
+ if c.ID == removeID {
+ c.Destroy()
+ }
+ }
+ }
+
+ // Check the new adjusted oom_score_adj.
+ if testCase.ExpectedAfterRemove != nil {
+ scoreAfterRemove, err := specutils.GetOOMScoreAdj(sandboxPid)
+ if err != nil {
+ t.Fatalf("error reading sandbox oom_score_adj: %v", err)
+ }
+ if scoreAfterRemove != *testCase.ExpectedAfterRemove {
+ t.Errorf("sandbox oom_score_adj got: %d, want: %d", scoreAfterRemove, *testCase.ExpectedAfterRemove)
+ }
+ }
+ })
+ }
+}
+
+func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) {
+ var specs []*specs.Spec
+ var ids []string
+ rootID := testutil.UniqueContainerID()
+
+ for i, cmd := range cmds {
+ spec := testutil.NewSpecWithArgs(cmd...)
+ if i == 0 {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeSandbox,
+ }
+ ids = append(ids, rootID)
+ } else {
+ spec.Annotations = map[string]string{
+ specutils.ContainerdContainerTypeAnnotation: specutils.ContainerdContainerTypeContainer,
+ specutils.ContainerdSandboxIDAnnotation: rootID,
+ }
+ ids = append(ids, testutil.UniqueContainerID())
+ }
+ specs = append(specs, spec)
+ }
+ return specs, ids
+}
+
+func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*container.Container, func(), error) {
+ // Setup root dir if one hasn't been provided.
+ if len(conf.RootDir) == 0 {
+ rootDir, err := testutil.SetupRootDir()
+ if err != nil {
+ return nil, nil, fmt.Errorf("error creating root dir: %v", err)
+ }
+ conf.RootDir = rootDir
+ }
+
+ var containers []*container.Container
+ var bundles []string
+ cleanup := func() {
+ for _, c := range containers {
+ c.Destroy()
+ }
+ for _, b := range bundles {
+ os.RemoveAll(b)
+ }
+ os.RemoveAll(conf.RootDir)
+ }
+ for i, spec := range specs {
+ bundleDir, err := testutil.SetupBundleDir(spec)
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error setting up container: %v", err)
+ }
+ bundles = append(bundles, bundleDir)
+
+ args := container.Args{
+ ID: ids[i],
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+ cont, err := container.New(conf, args)
+ if err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error creating container: %v", err)
+ }
+ containers = append(containers, cont)
+
+ if err := cont.Start(conf); err != nil {
+ cleanup()
+ return nil, nil, fmt.Errorf("error starting container: %v", err)
+ }
+ }
+ return containers, cleanup, nil
+}
diff --git a/test/root/root.go b/test/root/root.go
index 349c752cc..0f1d29faf 100644
--- a/test/root/root.go
+++ b/test/root/root.go
@@ -12,5 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package root is empty. See chroot_test.go for description.
+// Package root is used for tests that requires sysadmin privileges run. First,
+// follow the setup instruction in runsc/test/README.md. You should also have
+// docker, containerd, and crictl installed. To run these tests from the
+// project root directory:
+//
+// ./scripts/root_tests.sh
package root
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index 5616a8b7b..dfb4e2a97 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -1,25 +1,41 @@
# These packages are used to run language runtime tests inside gVisor sandboxes.
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
load("//test/runtimes:build_defs.bzl", "runtime_test")
package(licenses = ["notice"])
-go_library(
- name = "runtimes",
- srcs = ["runtimes.go"],
- importpath = "gvisor.dev/gvisor/test/runtimes",
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["runner.go"],
+ deps = [
+ "//runsc/dockerutil",
+ "//runsc/testutil",
+ ],
)
runtime_test(
- name = "runtimes_test",
- size = "small",
- srcs = ["runtimes_test.go"],
- embed = [":runtimes"],
- tags = [
- # Requires docker and runsc to be configured before the test runs.
- "manual",
- "local",
- ],
- deps = ["//runsc/testutil"],
+ image = "gcr.io/gvisor-presubmit/go1.12",
+ lang = "go",
+)
+
+runtime_test(
+ image = "gcr.io/gvisor-presubmit/java11",
+ lang = "java",
+)
+
+runtime_test(
+ image = "gcr.io/gvisor-presubmit/nodejs12.4.0",
+ lang = "nodejs",
+)
+
+runtime_test(
+ image = "gcr.io/gvisor-presubmit/php7.3.6",
+ lang = "php",
+)
+
+runtime_test(
+ image = "gcr.io/gvisor-presubmit/python3.7.3",
+ lang = "python",
)
diff --git a/test/runtimes/README.md b/test/runtimes/README.md
index 34d3507be..e41e78f77 100644
--- a/test/runtimes/README.md
+++ b/test/runtimes/README.md
@@ -16,10 +16,11 @@ The following runtimes are currently supported:
1) [Install and configure Docker](https://docs.docker.com/install/)
-2) Build each Docker container from the runtimes directory:
+2) Build each Docker container from the runtimes/images directory:
```bash
-$ docker build -f $LANG/Dockerfile [-t $NAME] .
+$ cd images
+$ docker build -f Dockerfile_$LANG [-t $NAME] .
```
### Testing:
diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl
index ac28cc037..5e3065342 100644
--- a/test/runtimes/build_defs.bzl
+++ b/test/runtimes/build_defs.bzl
@@ -1,19 +1,35 @@
"""Defines a rule for runsc test targets."""
-load("@io_bazel_rules_go//go:def.bzl", _go_test = "go_test")
-
# runtime_test is a macro that will create targets to run the given test target
# with different runtime options.
-def runtime_test(**kwargs):
- """Runs the given test target with different runtime options."""
- name = kwargs["name"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_hostnet"
- kwargs["args"] = ["--runtime-type=hostnet"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_kvm"
- kwargs["args"] = ["--runtime-type=kvm"]
- _go_test(**kwargs)
- kwargs["name"] = name + "_overlay"
- kwargs["args"] = ["--runtime-type=overlay"]
- _go_test(**kwargs)
+def runtime_test(
+ lang,
+ image,
+ shard_count = 50,
+ size = "enormous"):
+ sh_test(
+ name = lang + "_test",
+ srcs = ["runner.sh"],
+ args = [
+ "--lang",
+ lang,
+ "--image",
+ image,
+ ],
+ data = [
+ ":runner",
+ ],
+ size = size,
+ shard_count = shard_count,
+ tags = [
+ # Requires docker and runsc to be configured before the test runs.
+ "manual",
+ "local",
+ ],
+ )
+
+def sh_test(**kwargs):
+ """Wraps the standard sh_test."""
+ native.sh_test(
+ **kwargs
+ )
diff --git a/test/runtimes/common/BUILD b/test/runtimes/common/BUILD
deleted file mode 100644
index b4740bb97..000000000
--- a/test/runtimes/common/BUILD
+++ /dev/null
@@ -1,20 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "common",
- srcs = ["common.go"],
- importpath = "gvisor.dev/gvisor/test/runtimes/common",
- visibility = ["//:sandbox"],
-)
-
-go_test(
- name = "common_test",
- size = "small",
- srcs = ["common_test.go"],
- deps = [
- ":common",
- "//runsc/testutil",
- ],
-)
diff --git a/test/runtimes/common/common.go b/test/runtimes/common/common.go
deleted file mode 100644
index 0ff87fa8b..000000000
--- a/test/runtimes/common/common.go
+++ /dev/null
@@ -1,114 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package common executes functions for proctor binaries.
-package common
-
-import (
- "flag"
- "fmt"
- "os"
- "path/filepath"
- "regexp"
-)
-
-var (
- list = flag.Bool("list", false, "list all available tests")
- test = flag.String("test", "", "run a single test from the list of available tests")
- version = flag.Bool("v", false, "print out the version of node that is installed")
-)
-
-// TestRunner is an interface to be implemented in each proctor binary.
-type TestRunner interface {
- // ListTests returns a string slice of tests available to run.
- ListTests() ([]string, error)
-
- // RunTest runs a single test.
- RunTest(test string) error
-}
-
-// LaunchFunc parses flags passed by a proctor binary and calls the requested behavior.
-func LaunchFunc(tr TestRunner) error {
- flag.Parse()
-
- if *list && *test != "" {
- flag.PrintDefaults()
- return fmt.Errorf("cannot specify 'list' and 'test' flags simultaneously")
- }
- if *list {
- tests, err := tr.ListTests()
- if err != nil {
- return fmt.Errorf("failed to list tests: %v", err)
- }
- for _, test := range tests {
- fmt.Println(test)
- }
- return nil
- }
- if *version {
- fmt.Println(os.Getenv("LANG_NAME"), "version:", os.Getenv("LANG_VER"), "is installed.")
- return nil
- }
- if *test != "" {
- if err := tr.RunTest(*test); err != nil {
- return fmt.Errorf("test %q failed to run: %v", *test, err)
- }
- return nil
- }
-
- if err := runAllTests(tr); err != nil {
- return fmt.Errorf("error running all tests: %v", err)
- }
- return nil
-}
-
-// Search uses filepath.Walk to perform a search of the disk for test files
-// and returns a string slice of tests.
-func Search(root string, testFilter *regexp.Regexp) ([]string, error) {
- var testSlice []string
-
- err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
- name := filepath.Base(path)
-
- if info.IsDir() || !testFilter.MatchString(name) {
- return nil
- }
-
- relPath, err := filepath.Rel(root, path)
- if err != nil {
- return err
- }
- testSlice = append(testSlice, relPath)
- return nil
- })
-
- if err != nil {
- return nil, fmt.Errorf("walking %q: %v", root, err)
- }
-
- return testSlice, nil
-}
-
-func runAllTests(tr TestRunner) error {
- tests, err := tr.ListTests()
- if err != nil {
- return fmt.Errorf("failed to list tests: %v", err)
- }
- for _, test := range tests {
- if err := tr.RunTest(test); err != nil {
- return fmt.Errorf("test %q failed to run: %v", test, err)
- }
- }
- return nil
-}
diff --git a/test/runtimes/go/BUILD b/test/runtimes/go/BUILD
deleted file mode 100644
index ce971ee9d..000000000
--- a/test/runtimes/go/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-go",
- srcs = ["proctor-go.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/go/Dockerfile b/test/runtimes/go/Dockerfile
deleted file mode 100644
index 2d3477392..000000000
--- a/test/runtimes/go/Dockerfile
+++ /dev/null
@@ -1,35 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=1.12.5
-ENV LANG_NAME=Go
-
-RUN apt-get update && apt-get install -y \
- curl \
- gcc \
- git
-
-WORKDIR /root
-
-# Download Go 1.4 to use as a bootstrap for building Go from the source.
-RUN curl -o go1.4.linux-amd64.tar.gz https://dl.google.com/go/go1.4.linux-amd64.tar.gz
-RUN curl -LJO https://github.com/golang/go/archive/go${LANG_VER}.tar.gz
-RUN mkdir bootstr
-RUN tar -C bootstr -xzf go1.4.linux-amd64.tar.gz
-RUN tar -xzf go-go${LANG_VER}.tar.gz
-RUN mv go-go${LANG_VER} go
-
-ENV GOROOT=/root/go
-ENV GOROOT_BOOTSTRAP=/root/bootstr/go
-ENV LANG_DIR=${GOROOT}
-
-WORKDIR ${LANG_DIR}/src
-RUN ./make.bash
-# Pre-compile the tests for faster execution
-RUN ["/root/go/bin/go", "tool", "dist", "test", "-compile-only"]
-
-WORKDIR ${LANG_DIR}
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY go/proctor-go.go ${LANG_DIR}
-RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-go.go"]
-
-ENTRYPOINT ["/root/go/bin/proctor"]
diff --git a/test/runtimes/images/Dockerfile_go1.12 b/test/runtimes/images/Dockerfile_go1.12
new file mode 100644
index 000000000..ab9d6abf3
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_go1.12
@@ -0,0 +1,10 @@
+# Go is easy, since we already have everything we need to compile the proctor
+# binary and run the tests in the golang Docker image.
+FROM golang:1.12
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+# Pre-compile the tests so we don't need to do so in each test run.
+RUN ["go", "tool", "dist", "test", "-compile-only"]
+
+ENTRYPOINT ["/proctor", "--runtime=go"]
diff --git a/test/runtimes/images/Dockerfile_java11 b/test/runtimes/images/Dockerfile_java11
new file mode 100644
index 000000000..9b7c3d5a3
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_java11
@@ -0,0 +1,30 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ autoconf \
+ build-essential \
+ curl \
+ make \
+ openjdk-11-jdk \
+ unzip \
+ zip
+
+# Download the JDK test library.
+WORKDIR /root
+RUN set -ex \
+ && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk11/archive/76072a077ee1.tar.gz/test \
+ && tar -xzf /tmp/jdktests.tar.gz \
+ && mv jdk11-76072a077ee1/test test \
+ && rm -f /tmp/jdktests.tar.gz
+
+# Install jtreg and add to PATH.
+RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
+RUN tar -xzf jtreg.tar.gz
+ENV PATH="/root/jtreg/bin:$PATH"
+
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=java"]
diff --git a/test/runtimes/images/Dockerfile_nodejs12.4.0 b/test/runtimes/images/Dockerfile_nodejs12.4.0
new file mode 100644
index 000000000..26f68b487
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_nodejs12.4.0
@@ -0,0 +1,28 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ curl \
+ dumb-init \
+ g++ \
+ make \
+ python
+
+WORKDIR /root
+ARG VERSION=v12.4.0
+RUN curl -o node-${VERSION}.tar.gz https://nodejs.org/dist/${VERSION}/node-${VERSION}.tar.gz
+RUN tar -zxf node-${VERSION}.tar.gz
+
+WORKDIR /root/node-${VERSION}
+RUN ./configure
+RUN make
+RUN make test-build
+
+COPY --from=golang /proctor /proctor
+
+# Including dumb-init emulates the Linux "init" process, preventing the failure
+# of tests involving worker processes.
+ENTRYPOINT ["/usr/bin/dumb-init", "/proctor", "--runtime=nodejs"]
diff --git a/test/runtimes/images/Dockerfile_php7.3.6 b/test/runtimes/images/Dockerfile_php7.3.6
new file mode 100644
index 000000000..e6b4c6329
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_php7.3.6
@@ -0,0 +1,27 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+RUN apt-get update && apt-get install -y \
+ autoconf \
+ automake \
+ bison \
+ build-essential \
+ curl \
+ libtool \
+ libxml2-dev \
+ re2c
+
+WORKDIR /root
+ARG VERSION=7.3.6
+RUN curl -o php-${VERSION}.tar.gz https://www.php.net/distributions/php-${VERSION}.tar.gz
+RUN tar -zxf php-${VERSION}.tar.gz
+
+WORKDIR /root/php-${VERSION}
+RUN ./configure
+RUN make
+
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=php"]
diff --git a/test/runtimes/images/Dockerfile_python3.7.3 b/test/runtimes/images/Dockerfile_python3.7.3
new file mode 100644
index 000000000..905cd22d7
--- /dev/null
+++ b/test/runtimes/images/Dockerfile_python3.7.3
@@ -0,0 +1,30 @@
+# Compile the proctor binary.
+FROM golang:1.12 AS golang
+ADD ["proctor/", "/go/src/proctor/"]
+RUN ["go", "build", "-o", "/proctor", "/go/src/proctor"]
+
+FROM ubuntu:bionic
+
+RUN apt-get update && apt-get install -y \
+ curl \
+ gcc \
+ libbz2-dev \
+ libffi-dev \
+ liblzma-dev \
+ libreadline-dev \
+ libssl-dev \
+ make \
+ zlib1g-dev
+
+# Use flags -LJO to follow the html redirect and download .tar.gz.
+WORKDIR /root
+ARG VERSION=3.7.3
+RUN curl -LJO https://github.com/python/cpython/archive/v${VERSION}.tar.gz
+RUN tar -zxf cpython-${VERSION}.tar.gz
+
+WORKDIR /root/cpython-${VERSION}
+RUN ./configure --with-pydebug
+RUN make -s -j2
+
+COPY --from=golang /proctor /proctor
+ENTRYPOINT ["/proctor", "--runtime=python"]
diff --git a/test/runtimes/images/proctor/BUILD b/test/runtimes/images/proctor/BUILD
new file mode 100644
index 000000000..09dc6c42f
--- /dev/null
+++ b/test/runtimes/images/proctor/BUILD
@@ -0,0 +1,26 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_test")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "proctor",
+ srcs = [
+ "go.go",
+ "java.go",
+ "nodejs.go",
+ "php.go",
+ "proctor.go",
+ "python.go",
+ ],
+ visibility = ["//test/runtimes/images:__subpackages__"],
+)
+
+go_test(
+ name = "proctor_test",
+ size = "small",
+ srcs = ["proctor_test.go"],
+ embed = [":proctor"],
+ deps = [
+ "//runsc/testutil",
+ ],
+)
diff --git a/test/runtimes/go/proctor-go.go b/test/runtimes/images/proctor/go.go
index 3eb24576e..3e2d5d8db 100644
--- a/test/runtimes/go/proctor-go.go
+++ b/test/runtimes/images/proctor/go.go
@@ -12,50 +12,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-go is a utility that facilitates language testing for Go.
-
-// There are two types of Go tests: "Go tool tests" and "Go tests on disk".
-// "Go tool tests" are found and executed using `go tool dist test`.
-// "Go tests on disk" are found in the /test directory and are
-// executed using `go run run.go`.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"regexp"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
var (
- dir = os.Getenv("LANG_DIR")
- goBin = filepath.Join(dir, "bin/go")
- testDir = filepath.Join(dir, "test")
- testRegEx = regexp.MustCompile(`^.+\.go$`)
+ goTestRegEx = regexp.MustCompile(`^.+\.go$`)
// Directories with .dir contain helper files for tests.
// Exclude benchmarks and stress tests.
- dirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`)
+ goDirFilter = regexp.MustCompile(`^(bench|stress)\/.+$|^.+\.dir.+$`)
)
-type goRunner struct {
-}
+// Location of Go tests on disk.
+const goTestDir = "/usr/local/go/test"
-func main() {
- if err := common.LaunchFunc(goRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+// goRunner implements TestRunner for Go.
+//
+// There are two types of Go tests: "Go tool tests" and "Go tests on disk".
+// "Go tool tests" are found and executed using `go tool dist test`. "Go tests
+// on disk" are found in the /usr/local/go/test directory and are executed
+// using `go run run.go`.
+type goRunner struct{}
+
+var _ TestRunner = goRunner{}
-func (g goRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (goRunner) ListTests() ([]string, error) {
// Go tool dist test tests.
args := []string{"tool", "dist", "test", "-list"}
- cmd := exec.Command(filepath.Join(dir, "bin/go"), args...)
+ cmd := exec.Command("go", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -67,14 +59,14 @@ func (g goRunner) ListTests() ([]string, error) {
}
// Go tests on disk.
- diskSlice, err := common.Search(testDir, testRegEx)
+ diskSlice, err := search(goTestDir, goTestRegEx)
if err != nil {
return nil, err
}
// Remove items from /bench/, /stress/ and .dir files
diskFiltered := diskSlice[:0]
for _, file := range diskSlice {
- if !dirFilter.MatchString(file) {
+ if !goDirFilter.MatchString(file) {
diskFiltered = append(diskFiltered, file)
}
}
@@ -82,24 +74,17 @@ func (g goRunner) ListTests() ([]string, error) {
return append(toolSlice, diskFiltered...), nil
}
-func (g goRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (goRunner) TestCmd(test string) *exec.Cmd {
// Check if test exists on disk by searching for file of the same name.
// This will determine whether or not it is a Go test on disk.
if strings.HasSuffix(test, ".go") {
// Test has suffix ".go" which indicates a disk test, run it as such.
- cmd := exec.Command(goBin, "run", "run.go", "-v", "--", test)
- cmd.Dir = testDir
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run test: %v", err)
- }
- } else {
- // No ".go" suffix, run as a tool test.
- cmd := exec.Command(goBin, "tool", "dist", "test", "-run", test)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run test: %v", err)
- }
+ cmd := exec.Command("go", "run", "run.go", "-v", "--", test)
+ cmd.Dir = goTestDir
+ return cmd
}
- return nil
+
+ // No ".go" suffix, run as a tool test.
+ return exec.Command("go", "tool", "dist", "test", "-run", test)
}
diff --git a/test/runtimes/java/proctor-java.go b/test/runtimes/images/proctor/java.go
index 7f6a66f4f..8b362029d 100644
--- a/test/runtimes/java/proctor-java.go
+++ b/test/runtimes/images/proctor/java.go
@@ -12,40 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-java is a utility that facilitates language testing for Java.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"regexp"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
- hash = os.Getenv("LANG_HASH")
- jtreg = filepath.Join(dir, "jtreg/bin/jtreg")
- exclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`)
-)
+// Directories to exclude from tests.
+var javaExclDirs = regexp.MustCompile(`(^(sun\/security)|(java\/util\/stream)|(java\/time)| )`)
-type javaRunner struct {
-}
+// Location of java tests.
+const javaTestDir = "/root/test/jdk"
-func main() {
- if err := common.LaunchFunc(javaRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+// javaRunner implements TestRunner for Java.
+type javaRunner struct{}
+
+var _ TestRunner = javaRunner{}
-func (j javaRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (javaRunner) ListTests() ([]string, error) {
args := []string{
- "-dir:/root/jdk11-" + hash + "/test/jdk",
+ "-dir:" + javaTestDir,
"-ignore:quiet",
"-a",
"-listtests",
@@ -54,7 +45,7 @@ func (j javaRunner) ListTests() ([]string, error) {
":jdk_sound",
":jdk_imageio",
}
- cmd := exec.Command(jtreg, args...)
+ cmd := exec.Command("jtreg", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -62,19 +53,19 @@ func (j javaRunner) ListTests() ([]string, error) {
}
var testSlice []string
for _, test := range strings.Split(string(out), "\n") {
- if !exclDirs.MatchString(test) {
+ if !javaExclDirs.MatchString(test) {
testSlice = append(testSlice, test)
}
}
return testSlice, nil
}
-func (j javaRunner) RunTest(test string) error {
- args := []string{"-noreport", "-dir:/root/jdk11-" + hash + "/test/jdk", test}
- cmd := exec.Command(jtreg, args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
+// TestCmd implements TestRunner.TestCmd.
+func (javaRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{
+ "-noreport",
+ "-dir:" + javaTestDir,
+ test,
}
- return nil
+ return exec.Command("jtreg", args...)
}
diff --git a/test/runtimes/images/proctor/nodejs.go b/test/runtimes/images/proctor/nodejs.go
new file mode 100644
index 000000000..bd57db444
--- /dev/null
+++ b/test/runtimes/images/proctor/nodejs.go
@@ -0,0 +1,46 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+ "os/exec"
+ "path/filepath"
+ "regexp"
+)
+
+var nodejsTestRegEx = regexp.MustCompile(`^test-[^-].+\.js$`)
+
+// Location of nodejs tests relative to working dir.
+const nodejsTestDir = "test"
+
+// nodejsRunner implements TestRunner for NodeJS.
+type nodejsRunner struct{}
+
+var _ TestRunner = nodejsRunner{}
+
+// ListTests implements TestRunner.ListTests.
+func (nodejsRunner) ListTests() ([]string, error) {
+ testSlice, err := search(nodejsTestDir, nodejsTestRegEx)
+ if err != nil {
+ return nil, err
+ }
+ return testSlice, nil
+}
+
+// TestCmd implements TestRunner.TestCmd.
+func (nodejsRunner) TestCmd(test string) *exec.Cmd {
+ args := []string{filepath.Join("tools", "test.py"), test}
+ return exec.Command("/usr/bin/python", args...)
+}
diff --git a/test/runtimes/php/proctor-php.go b/test/runtimes/images/proctor/php.go
index e6c5fabdf..9115040e1 100644
--- a/test/runtimes/php/proctor-php.go
+++ b/test/runtimes/images/proctor/php.go
@@ -12,47 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-php is a utility that facilitates language testing for PHP.
package main
import (
- "fmt"
- "log"
- "os"
"os/exec"
"regexp"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
- testRegEx = regexp.MustCompile(`^.+\.phpt$`)
-)
+var phpTestRegEx = regexp.MustCompile(`^.+\.phpt$`)
-type phpRunner struct {
-}
+// phpRunner implements TestRunner for PHP.
+type phpRunner struct{}
-func main() {
- if err := common.LaunchFunc(phpRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
+var _ TestRunner = phpRunner{}
-func (p phpRunner) ListTests() ([]string, error) {
- testSlice, err := common.Search(dir, testRegEx)
+// ListTests implements TestRunner.ListTests.
+func (phpRunner) ListTests() ([]string, error) {
+ testSlice, err := search(".", phpTestRegEx)
if err != nil {
return nil, err
}
return testSlice, nil
}
-func (p phpRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (phpRunner) TestCmd(test string) *exec.Cmd {
args := []string{"test", "TESTS=" + test}
- cmd := exec.Command("make", args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
+ return exec.Command("make", args...)
}
diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go
new file mode 100644
index 000000000..e6178e82b
--- /dev/null
+++ b/test/runtimes/images/proctor/proctor.go
@@ -0,0 +1,154 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary proctor runs the test for a particular runtime. It is meant to be
+// included in Docker images for all runtime tests.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "os"
+ "os/exec"
+ "os/signal"
+ "path/filepath"
+ "regexp"
+ "syscall"
+)
+
+// TestRunner is an interface that must be implemented for each runtime
+// integrated with proctor.
+type TestRunner interface {
+ // ListTests returns a string slice of tests available to run.
+ ListTests() ([]string, error)
+
+ // TestCmd returns an *exec.Cmd that will run the given test.
+ TestCmd(test string) *exec.Cmd
+}
+
+var (
+ runtime = flag.String("runtime", "", "name of runtime")
+ list = flag.Bool("list", false, "list all available tests")
+ test = flag.String("test", "", "run a single test from the list of available tests")
+ pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
+)
+
+func main() {
+ flag.Parse()
+
+ if *pause {
+ pauseAndReap()
+ panic("pauseAndReap should never return")
+ }
+
+ if *runtime == "" {
+ log.Fatalf("runtime flag must be provided")
+ }
+
+ tr, err := testRunnerForRuntime(*runtime)
+ if err != nil {
+ log.Fatalf("%v", err)
+ }
+
+ // List tests.
+ if *list {
+ tests, err := tr.ListTests()
+ if err != nil {
+ log.Fatalf("failed to list tests: %v", err)
+ }
+ for _, test := range tests {
+ fmt.Println(test)
+ }
+ return
+ }
+
+ // Run a single test.
+ if *test == "" {
+ log.Fatalf("test flag must be provided")
+ }
+ cmd := tr.TestCmd(*test)
+ cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
+ if err := cmd.Run(); err != nil {
+ log.Fatalf("FAIL: %v", err)
+ }
+}
+
+// testRunnerForRuntime returns a new TestRunner for the given runtime.
+func testRunnerForRuntime(runtime string) (TestRunner, error) {
+ switch runtime {
+ case "go":
+ return goRunner{}, nil
+ case "java":
+ return javaRunner{}, nil
+ case "nodejs":
+ return nodejsRunner{}, nil
+ case "php":
+ return phpRunner{}, nil
+ case "python":
+ return pythonRunner{}, nil
+ }
+ return nil, fmt.Errorf("invalid runtime %q", runtime)
+}
+
+// pauseAndReap is like init. It runs forever and reaps any children.
+func pauseAndReap() {
+ // Get notified of any new children.
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch, syscall.SIGCHLD)
+
+ for {
+ if _, ok := <-ch; !ok {
+ // Channel closed. This should not happen.
+ panic("signal channel closed")
+ }
+
+ // Reap the child.
+ for {
+ if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 {
+ break
+ }
+ }
+ }
+}
+
+// search is a helper function to find tests in the given directory that match
+// the regex.
+func search(root string, testFilter *regexp.Regexp) ([]string, error) {
+ var testSlice []string
+
+ err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+
+ name := filepath.Base(path)
+
+ if info.IsDir() || !testFilter.MatchString(name) {
+ return nil
+ }
+
+ relPath, err := filepath.Rel(root, path)
+ if err != nil {
+ return err
+ }
+ testSlice = append(testSlice, relPath)
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("walking %q: %v", root, err)
+ }
+
+ return testSlice, nil
+}
diff --git a/test/runtimes/common/common_test.go b/test/runtimes/images/proctor/proctor_test.go
index 65875b41b..6bb61d142 100644
--- a/test/runtimes/common/common_test.go
+++ b/test/runtimes/images/proctor/proctor_test.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package common_test
+package main
import (
"io/ioutil"
@@ -24,7 +24,6 @@ import (
"testing"
"gvisor.dev/gvisor/runsc/testutil"
- "gvisor.dev/gvisor/test/runtimes/common"
)
func touch(t *testing.T, name string) {
@@ -48,9 +47,9 @@ func TestSearchEmptyDir(t *testing.T) {
var want []string
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := common.Search(td, testFilter)
+ got, err := search(td, testFilter)
if err != nil {
- t.Errorf("Search error: %v", err)
+ t.Errorf("search error: %v", err)
}
if !reflect.DeepEqual(got, want) {
@@ -117,9 +116,9 @@ func TestSearch(t *testing.T) {
}
testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`)
- got, err := common.Search(td, testFilter)
+ got, err := search(td, testFilter)
if err != nil {
- t.Errorf("Search error: %v", err)
+ t.Errorf("search error: %v", err)
}
if !reflect.DeepEqual(got, want) {
diff --git a/test/runtimes/python/proctor-python.go b/test/runtimes/images/proctor/python.go
index 35e28a7df..b9e0fbe6f 100644
--- a/test/runtimes/python/proctor-python.go
+++ b/test/runtimes/images/proctor/python.go
@@ -12,36 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary proctor-python is a utility that facilitates language testing for Pyhton.
package main
import (
"fmt"
- "log"
"os"
"os/exec"
- "path/filepath"
"strings"
-
- "gvisor.dev/gvisor/test/runtimes/common"
)
-var (
- dir = os.Getenv("LANG_DIR")
-)
+// pythonRunner implements TestRunner for Python.
+type pythonRunner struct{}
-type pythonRunner struct {
-}
+var _ TestRunner = pythonRunner{}
-func main() {
- if err := common.LaunchFunc(pythonRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
-
-func (p pythonRunner) ListTests() ([]string, error) {
+// ListTests implements TestRunner.ListTests.
+func (pythonRunner) ListTests() ([]string, error) {
args := []string{"-m", "test", "--list-tests"}
- cmd := exec.Command(filepath.Join(dir, "python"), args...)
+ cmd := exec.Command("./python", args...)
cmd.Stderr = os.Stderr
out, err := cmd.Output()
if err != nil {
@@ -54,12 +42,8 @@ func (p pythonRunner) ListTests() ([]string, error) {
return toolSlice, nil
}
-func (p pythonRunner) RunTest(test string) error {
+// TestCmd implements TestRunner.TestCmd.
+func (pythonRunner) TestCmd(test string) *exec.Cmd {
args := []string{"-m", "test", test}
- cmd := exec.Command(filepath.Join(dir, "python"), args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
+ return exec.Command("./python", args...)
}
diff --git a/test/runtimes/java/BUILD b/test/runtimes/java/BUILD
deleted file mode 100644
index 8c39d39ec..000000000
--- a/test/runtimes/java/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-java",
- srcs = ["proctor-java.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/java/Dockerfile b/test/runtimes/java/Dockerfile
deleted file mode 100644
index 1a61d9d8f..000000000
--- a/test/runtimes/java/Dockerfile
+++ /dev/null
@@ -1,36 +0,0 @@
-FROM ubuntu:bionic
-# This hash is associated with a specific JDK release and needed for ensuring
-# the same version is downloaded every time.
-ENV LANG_HASH=76072a077ee1
-ENV LANG_VER=11
-ENV LANG_NAME=Java
-
-RUN apt-get update && apt-get install -y \
- autoconf \
- build-essential \
- curl \
- make \
- openjdk-${LANG_VER}-jdk \
- unzip \
- zip
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-# Download the JDK test library.
-RUN set -ex \
- && curl -fsSL --retry 10 -o /tmp/jdktests.tar.gz http://hg.openjdk.java.net/jdk/jdk${LANG_VER}/archive/${LANG_HASH}.tar.gz/test \
- && tar -xzf /tmp/jdktests.tar.gz -C /root \
- && rm -f /tmp/jdktests.tar.gz
-
-RUN curl -o jtreg.tar.gz https://ci.adoptopenjdk.net/view/Dependencies/job/jtreg/lastSuccessfulBuild/artifact/jtreg-4.2.0-tip.tar.gz
-RUN tar -xzf jtreg.tar.gz
-
-ENV LANG_DIR=/root
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY java/proctor-java.go ${LANG_DIR}
-RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-java.go"]
-
-ENTRYPOINT ["/root/go/bin/proctor"]
diff --git a/test/runtimes/nodejs/BUILD b/test/runtimes/nodejs/BUILD
deleted file mode 100644
index 0594c250b..000000000
--- a/test/runtimes/nodejs/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-nodejs",
- srcs = ["proctor-nodejs.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/nodejs/Dockerfile b/test/runtimes/nodejs/Dockerfile
deleted file mode 100644
index ce2943af8..000000000
--- a/test/runtimes/nodejs/Dockerfile
+++ /dev/null
@@ -1,31 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=12.4.0
-ENV LANG_NAME=Node
-
-RUN apt-get update && apt-get install -y \
- curl \
- dumb-init \
- g++ \
- make \
- python
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-RUN curl -o node-v${LANG_VER}.tar.gz https://nodejs.org/dist/v${LANG_VER}/node-v${LANG_VER}.tar.gz
-RUN tar -zxf node-v${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/node-v${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure
-RUN make
-RUN make test-build
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY nodejs/proctor-nodejs.go ${LANG_DIR}
-RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-nodejs.go"]
-
-# Including dumb-init emulates the Linux "init" process, preventing the failure
-# of tests involving worker processes.
-ENTRYPOINT ["/usr/bin/dumb-init", "/root/go/bin/proctor"]
diff --git a/test/runtimes/nodejs/proctor-nodejs.go b/test/runtimes/nodejs/proctor-nodejs.go
deleted file mode 100644
index 0624f6a0d..000000000
--- a/test/runtimes/nodejs/proctor-nodejs.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Binary proctor-nodejs is a utility that facilitates language testing for NodeJS.
-package main
-
-import (
- "fmt"
- "log"
- "os"
- "os/exec"
- "path/filepath"
- "regexp"
-
- "gvisor.dev/gvisor/test/runtimes/common"
-)
-
-var (
- dir = os.Getenv("LANG_DIR")
- testDir = filepath.Join(dir, "test")
- testRegEx = regexp.MustCompile(`^test-[^-].+\.js$`)
-)
-
-type nodejsRunner struct {
-}
-
-func main() {
- if err := common.LaunchFunc(nodejsRunner{}); err != nil {
- log.Fatalf("Failed to start: %v", err)
- }
-}
-
-func (n nodejsRunner) ListTests() ([]string, error) {
- testSlice, err := common.Search(testDir, testRegEx)
- if err != nil {
- return nil, err
- }
- return testSlice, nil
-}
-
-func (n nodejsRunner) RunTest(test string) error {
- args := []string{filepath.Join(dir, "tools", "test.py"), test}
- cmd := exec.Command("/usr/bin/python", args...)
- cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("failed to run: %v", err)
- }
- return nil
-}
diff --git a/test/runtimes/php/BUILD b/test/runtimes/php/BUILD
deleted file mode 100644
index 31799b77a..000000000
--- a/test/runtimes/php/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-php",
- srcs = ["proctor-php.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/php/Dockerfile b/test/runtimes/php/Dockerfile
deleted file mode 100644
index d79babe58..000000000
--- a/test/runtimes/php/Dockerfile
+++ /dev/null
@@ -1,31 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=7.3.6
-ENV LANG_NAME=PHP
-
-RUN apt-get update && apt-get install -y \
- autoconf \
- automake \
- bison \
- build-essential \
- curl \
- libtool \
- libxml2-dev \
- re2c
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-RUN curl -o php-${LANG_VER}.tar.gz https://www.php.net/distributions/php-${LANG_VER}.tar.gz
-RUN tar -zxf php-${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/php-${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure
-RUN make
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY php/proctor-php.go ${LANG_DIR}
-RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-php.go"]
-
-ENTRYPOINT ["/root/go/bin/proctor"]
diff --git a/test/runtimes/python/BUILD b/test/runtimes/python/BUILD
deleted file mode 100644
index 37fd6a0f2..000000000
--- a/test/runtimes/python/BUILD
+++ /dev/null
@@ -1,9 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-
-package(licenses = ["notice"])
-
-go_binary(
- name = "proctor-python",
- srcs = ["proctor-python.go"],
- deps = ["//test/runtimes/common"],
-)
diff --git a/test/runtimes/python/Dockerfile b/test/runtimes/python/Dockerfile
deleted file mode 100644
index 5ae328890..000000000
--- a/test/runtimes/python/Dockerfile
+++ /dev/null
@@ -1,33 +0,0 @@
-FROM ubuntu:bionic
-ENV LANG_VER=3.7.3
-ENV LANG_NAME=Python
-
-RUN apt-get update && apt-get install -y \
- curl \
- gcc \
- libbz2-dev \
- libffi-dev \
- liblzma-dev \
- libreadline-dev \
- libssl-dev \
- make \
- zlib1g-dev
-
-WORKDIR /root
-RUN curl -o go.tar.gz https://dl.google.com/go/go1.12.6.linux-amd64.tar.gz
-RUN tar -zxf go.tar.gz
-
-# Use flags -LJO to follow the html redirect and download .tar.gz.
-RUN curl -LJO https://github.com/python/cpython/archive/v${LANG_VER}.tar.gz
-RUN tar -zxf cpython-${LANG_VER}.tar.gz
-ENV LANG_DIR=/root/cpython-${LANG_VER}
-
-WORKDIR ${LANG_DIR}
-RUN ./configure --with-pydebug
-RUN make -s -j2
-
-COPY common /root/go/src/gvisor.dev/gvisor/test/runtimes/common/common
-COPY python/proctor-python.go ${LANG_DIR}
-RUN ["/root/go/bin/go", "build", "-o", "/root/go/bin/proctor", "proctor-python.go"]
-
-ENTRYPOINT ["/root/go/bin/proctor"]
diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go
new file mode 100644
index 000000000..3a15f59a7
--- /dev/null
+++ b/test/runtimes/runner.go
@@ -0,0 +1,147 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Binary runner runs the runtime tests in a Docker container.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/testutil"
+)
+
+var (
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+)
+
+// Wait time for each test to run.
+const timeout = 5 * time.Minute
+
+func main() {
+ flag.Parse()
+ if *lang == "" || *image == "" {
+ fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
+ os.Exit(1)
+ }
+
+ os.Exit(runTests())
+}
+
+// runTests is a helper that is called by main. It exists so that we can run
+// defered functions before exiting. It returns an exit code that should be
+// passed to os.Exit.
+func runTests() int {
+ // Create a single docker container that will be used for all tests.
+ d := dockerutil.MakeDocker("gvisor-" + *lang)
+ defer d.CleanUp()
+
+ // Get a slice of tests to run. This will also start a single Docker
+ // container that will be used to run each test. The final test will
+ // stop the Docker container.
+ tests, err := getTests(d)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%s\n", err.Error())
+ return 1
+ }
+
+ m := testing.MainStart(testDeps{}, tests, nil, nil)
+ return m.Run()
+}
+
+// getTests returns a slice of tests to run, subject to the shard size and
+// index.
+func getTests(d dockerutil.Docker) ([]testing.InternalTest, error) {
+ // Pull the image.
+ if err := dockerutil.Pull(*image); err != nil {
+ return nil, fmt.Errorf("docker pull %q failed: %v", *image, err)
+ }
+
+ // Run proctor with --pause flag to keep container alive forever.
+ if err := d.Run(*image, "--pause"); err != nil {
+ return nil, fmt.Errorf("docker run failed: %v", err)
+ }
+
+ // Get a list of all tests in the image.
+ list, err := d.Exec("/proctor", "--runtime", *lang, "--list")
+ if err != nil {
+ return nil, fmt.Errorf("docker exec failed: %v", err)
+ }
+
+ // Calculate a subset of tests to run corresponding to the current
+ // shard.
+ tests := strings.Fields(list)
+ sort.Strings(tests)
+ begin, end, err := testutil.TestBoundsForShard(len(tests))
+ if err != nil {
+ return nil, fmt.Errorf("TestsForShard() failed: %v", err)
+ }
+ log.Printf("Got bounds [%d:%d) for shard out of %d total tests", begin, end, len(tests))
+ tests = tests[begin:end]
+
+ var itests []testing.InternalTest
+ for _, tc := range tests {
+ // Capture tc in this scope.
+ tc := tc
+ itests = append(itests, testing.InternalTest{
+ Name: tc,
+ F: func(t *testing.T) {
+ var (
+ now = time.Now()
+ done = make(chan struct{})
+ output string
+ err error
+ )
+ go func() {
+ fmt.Printf("RUNNING %s...\n", tc)
+ output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ if err == nil {
+ fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now))
+ return
+ }
+ t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output)
+ case <-time.After(timeout):
+ t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output)
+ }
+ },
+ })
+ }
+ return itests, nil
+}
+
+// testDeps implements testing.testDeps (an unexported interface), and is
+// required to use testing.MainStart.
+type testDeps struct{}
+
+func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil }
+func (f testDeps) StartCPUProfile(io.Writer) error { return nil }
+func (f testDeps) StopCPUProfile() {}
+func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
+func (f testDeps) ImportPath() string { return "" }
+func (f testDeps) StartTestLog(io.Writer) {}
+func (f testDeps) StopTestLog() error { return nil }
diff --git a/kokoro/run_tests.sh b/test/runtimes/runner.sh
index 5552da11c..a8d9a3460 100644..100755
--- a/kokoro/run_tests.sh
+++ b/test/runtimes/runner.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-# Copyright 2019 The gVisor Authors.
+# Copyright 2018 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,16 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-set -xeo pipefail
+set -euf -x -o pipefail
-# This file is a temporary bridge. We will create multiple independent Kokoro
-# workflows that call each of the test scripts independently.
+echo -- "$@"
+
+# Create outputs dir if it does not exist.
+if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then
+ mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}"
+ chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}"
+fi
+
+# Update the timestamp on the shard status file. Bazel looks for this.
+touch "${TEST_SHARD_STATUS_FILE}"
+
+# Get location of runner binary.
+readonly runner=$(find "${TEST_SRCDIR}" -name runner)
+
+# Pass the arguments of this script directly to the runner.
+exec "${runner}" "$@"
-# Run all the tests in sequence.
-$(dirname $0)/../scripts/do_tests.sh
-$(dirname $0)/../scripts/make_tests.sh
-$(dirname $0)/../scripts/root_tests.sh
-$(dirname $0)/../scripts/docker_tests.sh
-$(dirname $0)/../scripts/overlay_tests.sh
-$(dirname $0)/../scripts/hostnet_tests.sh
-$(dirname $0)/../scripts/simple_tests.sh
diff --git a/test/runtimes/runtimes_test.go b/test/runtimes/runtimes_test.go
deleted file mode 100644
index 0ff5dda02..000000000
--- a/test/runtimes/runtimes_test.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package runtimes
-
-import (
- "strings"
- "testing"
- "time"
-
- "gvisor.dev/gvisor/runsc/testutil"
-)
-
-// Wait time for each test to run.
-const timeout = 180 * time.Second
-
-// Helper function to execute the docker container associated with the
-// language passed. Captures the output of the list function and executes
-// each test individually, supplying any errors recieved.
-func testLang(t *testing.T, lang string) {
- t.Helper()
-
- img := "gcr.io/gvisor-presubmit/" + lang
- if err := testutil.Pull(img); err != nil {
- t.Fatalf("docker pull failed: %v", err)
- }
-
- c := testutil.MakeDocker("gvisor-list")
-
- list, err := c.RunFg(img, "--list")
- if err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
- c.CleanUp()
-
- tests := strings.Fields(list)
-
- for _, tc := range tests {
- tc := tc
- t.Run(tc, func(t *testing.T) {
- d := testutil.MakeDocker("gvisor-test")
- if err := d.Run(img, "--test", tc); err != nil {
- t.Fatalf("docker test %q failed to run: %v", tc, err)
- }
- defer d.CleanUp()
-
- status, err := d.Wait(timeout)
- if err != nil {
- t.Fatalf("docker test %q failed to wait: %v", tc, err)
- }
- if status == 0 {
- t.Logf("test %q passed", tc)
- return
- }
- logs, err := d.Logs()
- if err != nil {
- t.Fatalf("docker test %q failed to supply logs: %v", tc, err)
- }
- t.Errorf("test %q failed: %v", tc, logs)
- })
- }
-}
-
-func TestGo(t *testing.T) {
- testLang(t, "go")
-}
-
-func TestJava(t *testing.T) {
- testLang(t, "java")
-}
-
-func TestNodejs(t *testing.T) {
- testLang(t, "nodejs")
-}
-
-func TestPhp(t *testing.T) {
- testLang(t, "php")
-}
-
-func TestPython(t *testing.T) {
- testLang(t, "python")
-}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 0135435ea..63e4c63dd 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -321,6 +321,10 @@ syscall_test(
)
syscall_test(
+ test = "//test/syscalls/linux:pty_root_test",
+)
+
+syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:pwritev2_test",
)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index a964bef24..a4cebf46f 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1,3 +1,4 @@
+load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
load("//test/syscalls:build_defs.bzl", "select_for_linux")
package(
@@ -265,12 +266,15 @@ cc_binary(
],
linkstatic = 1,
deps = [
- # The heap check doesn't handle mremap properly.
+ # The heapchecker doesn't recognize that io_destroy munmaps.
"@com_google_googletest//:gtest",
"@com_google_absl//absl/strings",
"//test/util:cleanup",
"//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:memory_util",
"//test/util:posix_error",
+ "//test/util:proc_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
@@ -390,6 +394,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/synchronization",
"@com_google_googletest//:gtest",
],
@@ -408,6 +413,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -724,6 +730,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:timer_util",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/time",
@@ -972,6 +979,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
@@ -992,6 +1000,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -1278,8 +1287,10 @@ cc_binary(
srcs = ["pty.cc"],
linkstatic = 1,
deps = [
+ "//test/util:capability_util",
"//test/util:file_descriptor",
"//test/util:posix_error",
+ "//test/util:pty_util",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
@@ -1292,6 +1303,23 @@ cc_binary(
)
cc_binary(
+ name = "pty_root_test",
+ testonly = 1,
+ srcs = ["pty_root.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:posix_error",
+ "//test/util:pty_util",
+ "//test/util:test_main",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "partial_bad_buffer_test",
testonly = 1,
srcs = ["partial_bad_buffer.cc"],
@@ -1402,6 +1430,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -1418,6 +1447,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_googletest//:gtest",
],
)
@@ -1601,6 +1631,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"//test/util:time_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -1863,7 +1894,9 @@ cc_binary(
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ "//test/util:thread_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
@@ -1897,6 +1930,7 @@ cc_binary(
"//test/util:test_util",
"//test/util:thread_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
)
@@ -1949,6 +1983,24 @@ cc_binary(
)
cc_binary(
+ name = "signalfd_test",
+ testonly = 1,
+ srcs = ["signalfd.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:signal_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "sigprocmask_test",
testonly = 1,
srcs = ["sigprocmask.cc"],
@@ -1971,6 +2023,7 @@ cc_binary(
"//test/util:posix_error",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -3106,6 +3159,7 @@ cc_binary(
"//test/util:signal_util",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
@@ -3185,6 +3239,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
],
@@ -3276,6 +3331,7 @@ cc_binary(
"//test/util:multiprocess_util",
"//test/util:test_util",
"//test/util:time_util",
+ "@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/time",
"@com_google_googletest//:gtest",
],
diff --git a/test/syscalls/linux/aio.cc b/test/syscalls/linux/aio.cc
index 68dc05417..b27d4e10a 100644
--- a/test/syscalls/linux/aio.cc
+++ b/test/syscalls/linux/aio.cc
@@ -14,31 +14,57 @@
#include <fcntl.h>
#include <linux/aio_abi.h>
-#include <string.h>
#include <sys/mman.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
+#include <algorithm>
+#include <string>
+
#include "gtest/gtest.h"
#include "test/syscalls/linux/file_base.h"
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/proc_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+using ::testing::_;
+
namespace gvisor {
namespace testing {
namespace {
+// Returns the size of the VMA containing the given address.
+PosixErrorOr<size_t> VmaSizeAt(uintptr_t addr) {
+ ASSIGN_OR_RETURN_ERRNO(std::string proc_self_maps,
+ GetContents("/proc/self/maps"));
+ ASSIGN_OR_RETURN_ERRNO(auto entries, ParseProcMaps(proc_self_maps));
+ // Use binary search to find the first VMA that might contain addr.
+ ProcMapsEntry target = {};
+ target.end = addr;
+ auto it =
+ std::upper_bound(entries.begin(), entries.end(), target,
+ [](const ProcMapsEntry& x, const ProcMapsEntry& y) {
+ return x.end < y.end;
+ });
+ // Check that it actually contains addr.
+ if (it == entries.end() || addr < it->start) {
+ return PosixError(ENOENT, absl::StrCat("no VMA contains address ", addr));
+ }
+ return it->end - it->start;
+}
+
constexpr char kData[] = "hello world!";
int SubmitCtx(aio_context_t ctx, long nr, struct iocb** iocbpp) {
return syscall(__NR_io_submit, ctx, nr, iocbpp);
}
-} // namespace
-
class AIOTest : public FileTest {
public:
AIOTest() : ctx_(0) {}
@@ -124,10 +150,10 @@ TEST_F(AIOTest, BasicWrite) {
EXPECT_EQ(events[0].res, strlen(kData));
// Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
}
TEST_F(AIOTest, BadWrite) {
@@ -220,38 +246,25 @@ TEST_F(AIOTest, CloneVm) {
TEST_F(AIOTest, Mremap) {
// Setup a context that is 128 entries deep.
ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
struct iocb cb = CreateCallback();
struct iocb* cbs[1] = {&cb};
// Reserve address space for the mremap target so we have something safe to
// map over.
- //
- // N.B. We reserve 2 pages because we'll attempt to remap to 2 pages below.
- // That should fail with EFAULT, but will fail with EINVAL if this mmap
- // returns the page immediately below ctx_, as
- // [new_address, new_address+2*kPageSize) overlaps [ctx_, ctx_+kPageSize).
- void* new_address = mmap(nullptr, 2 * kPageSize, PROT_READ,
- MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<intptr_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, 2 * kPageSize), SyscallSucceeds());
- });
-
- // Test that remapping to a larger address fails.
- void* res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, 2 * kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(reinterpret_cast<intptr_t>(res), SyscallFailsWithErrno(EFAULT));
+ Mapping dst =
+ ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
// Remap context 'handle' to a different address.
- res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(
- reinterpret_cast<intptr_t>(res),
- SyscallSucceedsWithValue(reinterpret_cast<intptr_t>(new_address)));
- mmap_cleanup.Release();
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
aio_context_t old_ctx = ctx_;
- ctx_ = reinterpret_cast<aio_context_t>(new_address);
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ // io_destroy() will unmap dst now.
+ dst.release();
// Check that submitting the request with the old 'ctx_' fails.
ASSERT_THAT(SubmitCtx(old_ctx, 1, cbs), SyscallFailsWithErrno(EINVAL));
@@ -260,18 +273,12 @@ TEST_F(AIOTest, Mremap) {
ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
// Remap again.
- new_address =
- mmap(nullptr, kPageSize, PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup2 = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds());
- });
- res = mremap(reinterpret_cast<void*>(ctx_), kPageSize, kPageSize,
- MREMAP_FIXED | MREMAP_MAYMOVE, new_address);
- ASSERT_THAT(reinterpret_cast<int64_t>(res),
- SyscallSucceedsWithValue(reinterpret_cast<int64_t>(new_address)));
- mmap_cleanup2.Release();
- ctx_ = reinterpret_cast<aio_context_t>(new_address);
+ dst = ASSERT_NO_ERRNO_AND_VALUE(MmapAnon(ctx_size, PROT_READ, MAP_PRIVATE));
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ IsPosixErrorOkAndHolds(dst.ptr()));
+ ctx_ = reinterpret_cast<aio_context_t>(dst.addr());
+ dst.release();
// Get the reply with yet another 'ctx_' and verify it.
struct io_event events[1];
@@ -281,51 +288,33 @@ TEST_F(AIOTest, Mremap) {
EXPECT_EQ(events[0].res, strlen(kData));
// Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ char verify_buf[sizeof(kData)] = {};
+ ASSERT_THAT(read(test_file_fd_.get(), verify_buf, sizeof(kData)),
+ SyscallSucceedsWithValue(strlen(kData)));
+ EXPECT_STREQ(verify_buf, kData);
}
-// Tests that AIO context can be replaced with a different mapping at the same
-// address and continue working. Don't ask why, but Linux allows it.
-TEST_F(AIOTest, MremapOver) {
+// Tests that AIO context cannot be expanded with mremap.
+TEST_F(AIOTest, MremapExpansion) {
// Setup a context that is 128 entries deep.
ASSERT_THAT(SetupContext(128), SyscallSucceeds());
+ const size_t ctx_size =
+ ASSERT_NO_ERRNO_AND_VALUE(VmaSizeAt(reinterpret_cast<uintptr_t>(ctx_)));
- struct iocb cb = CreateCallback();
- struct iocb* cbs[1] = {&cb};
-
- ASSERT_THAT(Submit(1, cbs), SyscallSucceedsWithValue(1));
-
- // Allocate a new VMA, copy 'ctx_' content over, and remap it on top
- // of 'ctx_'.
- void* new_address = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
- MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
- ASSERT_THAT(reinterpret_cast<int64_t>(new_address), SyscallSucceeds());
- auto mmap_cleanup = Cleanup([new_address] {
- EXPECT_THAT(munmap(new_address, kPageSize), SyscallSucceeds());
- });
-
- memcpy(new_address, reinterpret_cast<void*>(ctx_), kPageSize);
- void* res =
- mremap(new_address, kPageSize, kPageSize, MREMAP_FIXED | MREMAP_MAYMOVE,
- reinterpret_cast<void*>(ctx_));
- ASSERT_THAT(reinterpret_cast<int64_t>(res), SyscallSucceedsWithValue(ctx_));
- mmap_cleanup.Release();
-
- // Everything continues to work just fine.
- struct io_event events[1];
- ASSERT_THAT(GetEvents(1, 1, events, nullptr), SyscallSucceedsWithValue(1));
- EXPECT_EQ(events[0].data, 0x123);
- EXPECT_EQ(events[0].obj, reinterpret_cast<long>(&cb));
- EXPECT_EQ(events[0].res, strlen(kData));
-
- // Verify that the file contains the contents.
- char verify_buf[32] = {};
- ASSERT_THAT(read(test_file_fd_.get(), &verify_buf[0], strlen(kData)),
- SyscallSucceeds());
- EXPECT_EQ(strcmp(kData, &verify_buf[0]), 0);
+ // Reserve address space for the mremap target so we have something safe to
+ // map over.
+ Mapping dst = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(ctx_size + kPageSize, PROT_NONE, MAP_PRIVATE));
+
+ // Test that remapping to a larger address range fails.
+ ASSERT_THAT(Mremap(reinterpret_cast<void*>(ctx_), ctx_size, dst.len(),
+ MREMAP_FIXED | MREMAP_MAYMOVE, dst.ptr()),
+ PosixErrorIs(EFAULT, _));
+
+ // mm/mremap.c:sys_mremap() => mremap_to() does do_munmap() of the destination
+ // before it hits the VM_DONTEXPAND check in vma_to_resize(), so we should no
+ // longer munmap it (another thread may have created a mapping there).
+ dst.release();
}
// Tests that AIO calls fail if context's address is inaccessible.
@@ -429,5 +418,7 @@ TEST_P(AIOVectorizedParamTest, BadIOVecs) {
INSTANTIATE_TEST_SUITE_P(BadIOVecs, AIOVectorizedParamTest,
::testing::Values(IOCB_CMD_PREADV, IOCB_CMD_PWRITEV));
+} // namespace
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/chown.cc b/test/syscalls/linux/chown.cc
index 2e82f0b3a..7a28b674d 100644
--- a/test/syscalls/linux/chown.cc
+++ b/test/syscalls/linux/chown.cc
@@ -16,10 +16,12 @@
#include <grp.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <vector>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/synchronization/notification.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -29,9 +31,9 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid1, 65534, "first scratch UID");
-DEFINE_int32(scratch_uid2, 65533, "second scratch UID");
-DEFINE_int32(scratch_gid, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID");
namespace gvisor {
namespace testing {
@@ -100,10 +102,12 @@ TEST_P(ChownParamTest, ChownFilePermissionDenied) {
// Change EUID and EGID.
//
// See note about POSIX below.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid1, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid1), -1),
+ SyscallSucceeds());
EXPECT_THAT(GetParam()(file.path(), geteuid(), getegid()),
PosixErrorIs(EPERM, ::testing::ContainsRegex("chown")));
@@ -125,8 +129,9 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
// setresuid syscall. However, we want this thread to have its own set of
// credentials different from the parent process, so we use the raw
// syscall.
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid2, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid2), -1),
+ SyscallSucceeds());
// Create file and immediately close it.
FileDescriptor fd =
@@ -143,12 +148,13 @@ TEST_P(ChownParamTest, ChownFileSucceedsAsRoot) {
fileCreated.WaitForNotification();
// Set file's owners to someone different.
- EXPECT_NO_ERRNO(GetParam()(filename, FLAGS_scratch_uid1, FLAGS_scratch_gid));
+ EXPECT_NO_ERRNO(GetParam()(filename, absl::GetFlag(FLAGS_scratch_uid1),
+ absl::GetFlag(FLAGS_scratch_gid)));
struct stat s;
EXPECT_THAT(stat(filename.c_str(), &s), SyscallSucceeds());
- EXPECT_EQ(s.st_uid, FLAGS_scratch_uid1);
- EXPECT_EQ(s.st_gid, FLAGS_scratch_gid);
+ EXPECT_EQ(s.st_uid, absl::GetFlag(FLAGS_scratch_uid1));
+ EXPECT_EQ(s.st_gid, absl::GetFlag(FLAGS_scratch_gid));
fileChowned.Notify();
}
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 2f8e7c9dd..8a45be12a 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -17,9 +17,12 @@
#include <syscall.h>
#include <unistd.h>
+#include <string>
+
#include "gtest/gtest.h"
#include "absl/base/macros.h"
#include "absl/base/port.h"
+#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/time/clock.h"
@@ -33,18 +36,19 @@
#include "test/util/test_util.h"
#include "test/util/timer_util.h"
-DEFINE_string(child_setlock_on, "",
- "Contains the path to try to set a file lock on.");
-DEFINE_bool(child_setlock_write, false,
- "Whether to set a writable lock (otherwise readable)");
-DEFINE_bool(blocking, false,
- "Whether to set a blocking lock (otherwise non-blocking).");
-DEFINE_bool(retry_eintr, false, "Whether to retry in the subprocess on EINTR.");
-DEFINE_uint64(child_setlock_start, 0, "The value of struct flock start");
-DEFINE_uint64(child_setlock_len, 0, "The value of struct flock len");
-DEFINE_int32(socket_fd, -1,
- "A socket to use for communicating more state back "
- "to the parent.");
+ABSL_FLAG(std::string, child_setlock_on, "",
+ "Contains the path to try to set a file lock on.");
+ABSL_FLAG(bool, child_setlock_write, false,
+ "Whether to set a writable lock (otherwise readable)");
+ABSL_FLAG(bool, blocking, false,
+ "Whether to set a blocking lock (otherwise non-blocking).");
+ABSL_FLAG(bool, retry_eintr, false,
+ "Whether to retry in the subprocess on EINTR.");
+ABSL_FLAG(uint64_t, child_setlock_start, 0, "The value of struct flock start");
+ABSL_FLAG(uint64_t, child_setlock_len, 0, "The value of struct flock len");
+ABSL_FLAG(int32_t, socket_fd, -1,
+ "A socket to use for communicating more state back "
+ "to the parent.");
namespace gvisor {
namespace testing {
@@ -918,25 +922,26 @@ TEST(FcntlTest, GetOwn) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (!FLAGS_child_setlock_on.empty()) {
- int socket_fd = FLAGS_socket_fd;
- int fd = open(FLAGS_child_setlock_on.c_str(), O_RDWR, 0666);
+ const std::string setlock_on = absl::GetFlag(FLAGS_child_setlock_on);
+ if (!setlock_on.empty()) {
+ int socket_fd = absl::GetFlag(FLAGS_socket_fd);
+ int fd = open(setlock_on.c_str(), O_RDWR, 0666);
if (fd == -1 && errno != 0) {
int err = errno;
- std::cerr << "CHILD open " << FLAGS_child_setlock_on << " failed " << err
+ std::cerr << "CHILD open " << setlock_on << " failed " << err
<< std::endl;
exit(err);
}
struct flock fl;
- if (FLAGS_child_setlock_write) {
+ if (absl::GetFlag(FLAGS_child_setlock_write)) {
fl.l_type = F_WRLCK;
} else {
fl.l_type = F_RDLCK;
}
fl.l_whence = SEEK_SET;
- fl.l_start = FLAGS_child_setlock_start;
- fl.l_len = FLAGS_child_setlock_len;
+ fl.l_start = absl::GetFlag(FLAGS_child_setlock_start);
+ fl.l_len = absl::GetFlag(FLAGS_child_setlock_len);
// Test the fcntl, no need to log, the error is unambiguously
// from fcntl at this point.
@@ -946,8 +951,8 @@ int main(int argc, char** argv) {
gvisor::testing::MonotonicTimer timer;
timer.Start();
do {
- ret = fcntl(fd, FLAGS_blocking ? F_SETLKW : F_SETLK, &fl);
- } while (FLAGS_retry_eintr && ret == -1 && errno == EINTR);
+ ret = fcntl(fd, absl::GetFlag(FLAGS_blocking) ? F_SETLKW : F_SETLK, &fl);
+ } while (absl::GetFlag(FLAGS_retry_eintr) && ret == -1 && errno == EINTR);
auto usec = absl::ToInt64Microseconds(timer.Duration());
if (ret == -1 && errno != 0) {
diff --git a/test/syscalls/linux/kill.cc b/test/syscalls/linux/kill.cc
index 18ad923b8..db29bd59c 100644
--- a/test/syscalls/linux/kill.cc
+++ b/test/syscalls/linux/kill.cc
@@ -21,6 +21,7 @@
#include <csignal>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
@@ -31,8 +32,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
-DEFINE_int32(scratch_gid, 65534, "scratch GID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "scratch GID");
using ::testing::Ge;
@@ -255,8 +256,8 @@ TEST(KillTest, ProcessGroups) {
TEST(KillTest, ChildDropsPrivsCannotKill) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SETUID)));
- int uid = FLAGS_scratch_uid;
- int gid = FLAGS_scratch_gid;
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
// Create the child that drops privileges and tries to kill the parent.
pid_t pid = fork();
@@ -331,8 +332,8 @@ TEST(KillTest, CanSIGCONTSameSession) {
EXPECT_TRUE(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP)
<< "status " << status;
- int uid = FLAGS_scratch_uid;
- int gid = FLAGS_scratch_gid;
+ const int uid = absl::GetFlag(FLAGS_scratch_uid);
+ const int gid = absl::GetFlag(FLAGS_scratch_gid);
// Drop privileges only in child process, or else this parent process won't be
// able to open some log files after the test ends.
diff --git a/test/syscalls/linux/link.cc b/test/syscalls/linux/link.cc
index a91703070..dd5352954 100644
--- a/test/syscalls/linux/link.cc
+++ b/test/syscalls/linux/link.cc
@@ -22,6 +22,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
@@ -31,7 +32,7 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
namespace gvisor {
namespace testing {
@@ -92,7 +93,8 @@ TEST(LinkTest, PermissionDenied) {
// threads have the same UIDs, so using the setuid wrapper sets all threads'
// real UID.
// Also drops capabilities.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
EXPECT_THAT(link(oldfile.path().c_str(), newname.c_str()),
SyscallFailsWithErrno(EPERM));
diff --git a/test/syscalls/linux/mremap.cc b/test/syscalls/linux/mremap.cc
index 64e435cb7..f0e5f7d82 100644
--- a/test/syscalls/linux/mremap.cc
+++ b/test/syscalls/linux/mremap.cc
@@ -35,17 +35,6 @@ namespace testing {
namespace {
-// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of
-// void* isn't directly compatible with SyscallSucceeds.
-PosixErrorOr<void*> Mremap(void* old_address, size_t old_size, size_t new_size,
- int flags, void* new_address) {
- void* rv = mremap(old_address, old_size, new_size, flags, new_address);
- if (rv == MAP_FAILED) {
- return PosixError(errno, "mremap failed");
- }
- return rv;
-}
-
// Fixture for mremap tests parameterized by mmap flags.
using MremapParamTest = ::testing::TestWithParam<int>;
diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc
index 65afb90f3..10e2a6dfc 100644
--- a/test/syscalls/linux/pipe.cc
+++ b/test/syscalls/linux/pipe.cc
@@ -168,6 +168,20 @@ TEST_P(PipeTest, Write) {
EXPECT_EQ(wbuf, rbuf);
}
+TEST_P(PipeTest, WritePage) {
+ SKIP_IF(!CreateBlocking());
+
+ std::vector<char> wbuf(kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ std::vector<char> rbuf(wbuf.size());
+
+ ASSERT_THAT(write(wfd_.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+ ASSERT_THAT(read(rfd_.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), wbuf.size()), 0);
+}
+
TEST_P(PipeTest, NonBlocking) {
SKIP_IF(!CreateNonBlocking());
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index bd1779557..d07571a5f 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -21,6 +21,7 @@
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/cleanup.h"
#include "test/util/multiprocess_util.h"
@@ -28,9 +29,9 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(prctl_no_new_privs_test_child, false,
- "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) "
- "plus an offset (see test source).");
+ABSL_FLAG(bool, prctl_no_new_privs_test_child, false,
+ "If true, exit with the return value of prctl(PR_GET_NO_NEW_PRIVS) "
+ "plus an offset (see test source).");
namespace gvisor {
namespace testing {
@@ -220,7 +221,7 @@ TEST(PrctlTest, RootDumpability) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_prctl_no_new_privs_test_child) {
+ if (absl::GetFlag(FLAGS_prctl_no_new_privs_test_child)) {
exit(gvisor::testing::kPrctlNoNewPrivsTestChildExitBase +
prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0));
}
diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc
index 00dd6523e..30f0d75b3 100644
--- a/test/syscalls/linux/prctl_setuid.cc
+++ b/test/syscalls/linux/prctl_setuid.cc
@@ -14,9 +14,11 @@
#include <sched.h>
#include <sys/prctl.h>
+
#include <string>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/logging.h"
#include "test/util/multiprocess_util.h"
@@ -24,12 +26,12 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "scratch UID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "scratch UID");
// This flag is used to verify that after an exec PR_GET_KEEPCAPS
// returns 0, the return code will be offset by kPrGetKeepCapsExitBase.
-DEFINE_bool(prctl_pr_get_keepcaps, false,
- "If true the test will verify that prctl with pr_get_keepcaps"
- "returns 0. The test will exit with the result of that check.");
+ABSL_FLAG(bool, prctl_pr_get_keepcaps, false,
+ "If true the test will verify that prctl with pr_get_keepcaps"
+ "returns 0. The test will exit with the result of that check.");
// These tests exist seperately from prctl because we need to start
// them as root. Setuid() has the behavior that permissions are fully
@@ -113,10 +115,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidNoKeepCaps) {
// call to only apply to this task. POSIX threads, however, require that
// all threads have the same UIDs, so using the setuid wrapper sets all
// threads' real UID.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
// Verify that we changed uid.
- EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid));
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
// Verify we lost the capability in the effective set, this always happens.
TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
@@ -157,10 +161,12 @@ TEST_F(PrctlKeepCapsSetuidTest, SetUidKeepCaps) {
// call to only apply to this task. POSIX threads, however, require that
// all threads have the same UIDs, so using the setuid wrapper sets all
// threads' real UID.
- EXPECT_THAT(syscall(SYS_setuid, FLAGS_scratch_uid), SyscallSucceeds());
+ EXPECT_THAT(syscall(SYS_setuid, absl::GetFlag(FLAGS_scratch_uid)),
+ SyscallSucceeds());
// Verify that we changed uid.
- EXPECT_THAT(getuid(), SyscallSucceedsWithValue(FLAGS_scratch_uid));
+ EXPECT_THAT(getuid(),
+ SyscallSucceedsWithValue(absl::GetFlag(FLAGS_scratch_uid)));
// Verify we lost the capability in the effective set, this always happens.
TEST_CHECK(!HaveCapability(CAP_SYS_ADMIN).ValueOrDie());
@@ -253,7 +259,7 @@ TEST_F(PrctlKeepCapsSetuidTest, PrGetKeepCaps) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_prctl_pr_get_keepcaps) {
+ if (absl::GetFlag(FLAGS_prctl_pr_get_keepcaps)) {
return gvisor::testing::kPrGetKeepCapsExitBase +
prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0);
}
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 2b753b7d1..6f07803d9 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -1882,7 +1882,9 @@ void CheckDuplicatesRecursively(std::string path) {
errno = 0;
DIR* dir = opendir(path.c_str());
if (dir == nullptr) {
- ASSERT_THAT(errno, ::testing::AnyOf(EPERM, EACCES)) << path;
+ // Ignore any directories we can't read or missing directories as the
+ // directory could have been deleted/mutated from the time the parent
+ // directory contents were read.
return;
}
auto dir_closer = Cleanup([&dir]() { closedir(dir); });
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index c097af196..efdaf202b 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -28,7 +28,7 @@ TEST(ProcNetIfInet6, Format) {
EXPECT_THAT(ifinet6,
::testing::MatchesRegex(
// Ex: "00000000000000000000000000000001 01 80 10 80 lo\n"
- "^([a-f\\d]{32}( [a-f\\d]{2}){4} +[a-z][a-z\\d]*\\n)+$"));
+ "^([a-f0-9]{32}( [a-f0-9]{2}){4} +[a-z][a-z0-9]*\n)+$"));
}
TEST(ProcSysNetIpv4Sack, Exists) {
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index abf2b1a04..8f3800380 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -27,6 +27,7 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/logging.h"
@@ -36,10 +37,10 @@
#include "test/util/thread_util.h"
#include "test/util/time_util.h"
-DEFINE_bool(ptrace_test_execve_child, false,
- "If true, run the "
- "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_"
- "TraceExit child workload.");
+ABSL_FLAG(bool, ptrace_test_execve_child, false,
+ "If true, run the "
+ "PtraceExecveTest_Execve_GetRegs_PeekUser_SIGKILL_TraceClone_"
+ "TraceExit child workload.");
namespace gvisor {
namespace testing {
@@ -1206,7 +1207,7 @@ TEST(PtraceTest, SeizeSetOptions) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_ptrace_test_execve_child) {
+ if (absl::GetFlag(FLAGS_ptrace_test_execve_child)) {
gvisor::testing::RunExecveChild();
}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index d1ab4703f..286388316 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -13,13 +13,17 @@
// limitations under the License.
#include <fcntl.h>
+#include <linux/capability.h>
#include <linux/major.h>
#include <poll.h>
+#include <sched.h>
+#include <signal.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/sysmacros.h>
#include <sys/types.h>
+#include <sys/wait.h>
#include <termios.h>
#include <unistd.h>
@@ -31,8 +35,10 @@
#include "absl/synchronization/notification.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
+#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
+#include "test/util/pty_util.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
@@ -370,25 +376,6 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count,
return PosixError(ETIMEDOUT, "Poll timed out");
}
-// Opens the slave end of the passed master as R/W and nonblocking.
-PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) {
- // Get pty index.
- int n;
- int ret = ioctl(master.get(), TIOCGPTN, &n);
- if (ret < 0) {
- return PosixError(errno, "ioctl(TIOCGPTN) failed");
- }
-
- // Unlock pts.
- int unlock = 0;
- ret = ioctl(master.get(), TIOCSPTLCK, &unlock);
- if (ret < 0) {
- return PosixError(errno, "ioctl(TIOSPTLCK) failed");
- }
-
- return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK);
-}
-
TEST(BasicPtyTest, StatUnopenedMaster) {
struct stat s;
ASSERT_THAT(stat("/dev/ptmx", &s), SyscallSucceeds());
@@ -1233,6 +1220,374 @@ TEST_F(PtyTest, SetMasterWindowSize) {
EXPECT_EQ(retrieved_ws.ws_col, kCols);
}
+class JobControlTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+ slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_));
+
+ // Make this a session leader, which also drops the controlling terminal.
+ // In the gVisor test environment, this test will be run as the session
+ // leader already (as the sentry init process).
+ if (!IsRunningOnGvisor()) {
+ ASSERT_THAT(setsid(), SyscallSucceeds());
+ }
+ }
+
+ // Master and slave ends of the PTY. Non-blocking.
+ FileDescriptor master_;
+ FileDescriptor slave_;
+};
+
+TEST_F(JobControlTest, SetTTYMaster) {
+ ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, SetTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, SetTTYNonLeader) {
+ // Fork a process that won't be the session leader.
+ pid_t child = fork();
+ if (!child) {
+ // We shouldn't be able to set the terminal.
+ TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+TEST_F(JobControlTest, SetTTYBadArg) {
+ // Despite the man page saying arg should be 0 here, Linux doesn't actually
+ // check.
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, SetTTYDifferentSession) {
+ SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Fork, join a new session, and try to steal the parent's controlling
+ // terminal, which should fail.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(setsid() >= 0);
+ // We shouldn't be able to steal the terminal.
+ TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+TEST_F(JobControlTest, ReleaseTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Make sure we're ignoring SIGHUP, which will be sent to this process once we
+ // disconnect they TTY.
+ struct sigaction sa = {
+ .sa_handler = SIG_IGN,
+ .sa_flags = 0,
+ };
+ sigemptyset(&sa.sa_mask);
+ struct sigaction old_sa;
+ EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds());
+ EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds());
+ EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, ReleaseUnsetTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY));
+}
+
+TEST_F(JobControlTest, ReleaseWrongTTY) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY));
+}
+
+TEST_F(JobControlTest, ReleaseTTYNonLeader) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+TEST_F(JobControlTest, ReleaseTTYDifferentSession) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ pid_t child = fork();
+ if (!child) {
+ // Join a new session, then try to disconnect.
+ TEST_PCHECK(setsid() >= 0);
+ TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+// Used by the child process spawned in ReleaseTTYSignals to track received
+// signals.
+static int received;
+
+void sig_handler(int signum) { received |= signum; }
+
+// When the session leader releases its controlling terminal, the foreground
+// process group gets SIGHUP, then SIGCONT. This test:
+// - Spawns 2 threads
+// - Has thread 1 return 0 if it gets both SIGHUP and SIGCONT
+// - Has thread 2 leave the foreground process group, and return non-zero if it
+// receives any signals.
+// - Has the parent thread release its controlling terminal
+// - Checks that thread 1 got both signals
+// - Checks that thread 2 didn't get any signals.
+TEST_F(JobControlTest, ReleaseTTYSignals) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ received = 0;
+ struct sigaction sa = {
+ .sa_handler = sig_handler,
+ .sa_flags = 0,
+ };
+ sigemptyset(&sa.sa_mask);
+ sigaddset(&sa.sa_mask, SIGHUP);
+ sigaddset(&sa.sa_mask, SIGCONT);
+ sigprocmask(SIG_BLOCK, &sa.sa_mask, NULL);
+
+ pid_t same_pgrp_child = fork();
+ if (!same_pgrp_child) {
+ // The child will wait for SIGHUP and SIGCONT, then return 0. It begins with
+ // SIGHUP and SIGCONT blocked. We install signal handlers for those signals,
+ // then use sigsuspend to wait for those specific signals.
+ TEST_PCHECK(!sigaction(SIGHUP, &sa, NULL));
+ TEST_PCHECK(!sigaction(SIGCONT, &sa, NULL));
+ sigset_t mask;
+ sigfillset(&mask);
+ sigdelset(&mask, SIGHUP);
+ sigdelset(&mask, SIGCONT);
+ while (received != (SIGHUP | SIGCONT)) {
+ sigsuspend(&mask);
+ }
+ _exit(0);
+ }
+
+ // We don't want to block these anymore.
+ sigprocmask(SIG_UNBLOCK, &sa.sa_mask, NULL);
+
+ // This child will return non-zero if either SIGHUP or SIGCONT are received.
+ pid_t diff_pgrp_child = fork();
+ if (!diff_pgrp_child) {
+ TEST_PCHECK(!setpgid(0, 0));
+ TEST_PCHECK(pause());
+ _exit(1);
+ }
+
+ EXPECT_THAT(setpgid(diff_pgrp_child, diff_pgrp_child), SyscallSucceeds());
+
+ // Make sure we're ignoring SIGHUP, which will be sent to this process once we
+ // disconnect they TTY.
+ struct sigaction sighup_sa = {
+ .sa_handler = SIG_IGN,
+ .sa_flags = 0,
+ };
+ sigemptyset(&sighup_sa.sa_mask);
+ struct sigaction old_sa;
+ EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds());
+
+ // Release the controlling terminal, sending SIGHUP and SIGCONT to all other
+ // processes in this process group.
+ EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds());
+
+ EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds());
+
+ // The child in the same process group will get signaled.
+ int wstatus;
+ EXPECT_THAT(waitpid(same_pgrp_child, &wstatus, 0),
+ SyscallSucceedsWithValue(same_pgrp_child));
+ EXPECT_EQ(wstatus, 0);
+
+ // The other child will not get signaled.
+ EXPECT_THAT(waitpid(diff_pgrp_child, &wstatus, WNOHANG),
+ SyscallSucceedsWithValue(0));
+ EXPECT_THAT(kill(diff_pgrp_child, SIGKILL), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, GetForegroundProcessGroup) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+ pid_t foreground_pgid;
+ pid_t pid;
+ ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid),
+ SyscallSucceeds());
+ ASSERT_THAT(pid = getpid(), SyscallSucceeds());
+
+ ASSERT_EQ(foreground_pgid, pid);
+}
+
+TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) {
+ // At this point there's no controlling terminal, so TIOCGPGRP should fail.
+ pid_t foreground_pgid;
+ ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid),
+ SyscallFailsWithErrno(ENOTTY));
+}
+
+// This test:
+// - sets itself as the foreground process group
+// - creates a child process in a new process group
+// - sets that child as the foreground process group
+// - kills its child and sets itself as the foreground process group.
+TEST_F(JobControlTest, SetForegroundProcessGroup) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp.
+ struct sigaction sa = {
+ .sa_handler = SIG_IGN,
+ .sa_flags = 0,
+ };
+ sigemptyset(&sa.sa_mask);
+ sigaction(SIGTTOU, &sa, NULL);
+
+ // Set ourself as the foreground process group.
+ ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds());
+
+ // Create a new process that just waits to be signaled.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(!pause());
+ // We should never reach this.
+ _exit(1);
+ }
+
+ // Make the child its own process group, then make it the controlling process
+ // group of the terminal.
+ ASSERT_THAT(setpgid(child, child), SyscallSucceeds());
+ ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds());
+
+ // Sanity check - we're still the controlling session.
+ ASSERT_EQ(getsid(0), getsid(child));
+
+ // Signal the child, wait for it to exit, then retake the terminal.
+ ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds());
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_TRUE(WIFSIGNALED(wstatus));
+ ASSERT_EQ(WTERMSIG(wstatus), SIGTERM);
+
+ // Set ourself as the foreground process.
+ pid_t pgid;
+ ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds());
+ ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds());
+}
+
+TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) {
+ pid_t pid = getpid();
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid),
+ SyscallFailsWithErrno(ENOTTY));
+}
+
+TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ pid_t pid = -1;
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Create a new process, put it in a new process group, make that group the
+ // foreground process group, then have the process wait.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(!setpgid(0, 0));
+ _exit(0);
+ }
+
+ // Wait for the child to exit.
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ // The child's process group doesn't exist anymore - this should fail.
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
+ SyscallFailsWithErrno(ESRCH));
+}
+
+TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) {
+ ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Create a new process and put it in a new session.
+ pid_t child = fork();
+ if (!child) {
+ TEST_PCHECK(setsid() >= 0);
+ // Tell the parent we're in a new session.
+ TEST_PCHECK(!raise(SIGSTOP));
+ TEST_PCHECK(!pause());
+ _exit(1);
+ }
+
+ // Wait for the child to tell us it's in a new session.
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED),
+ SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WSTOPSIG(wstatus));
+
+ // Child is in a new session, so we can't make it the foregroup process group.
+ EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
+ SyscallFailsWithErrno(EPERM));
+
+ EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds());
+}
+
+// Verify that we don't hang when creating a new session from an orphaned
+// process group (b/139968068). Calling setsid() creates an orphaned process
+// group, as process groups that contain the session's leading process are
+// orphans.
+//
+// We create 2 sessions in this test. The init process in gVisor is considered
+// not to be an orphan (see sessions.go), so we have to create a session from
+// which to create a session. The latter session is being created from an
+// orphaned process group.
+TEST_F(JobControlTest, OrphanRegression) {
+ pid_t session_2_leader = fork();
+ if (!session_2_leader) {
+ TEST_PCHECK(setsid() >= 0);
+
+ pid_t session_3_leader = fork();
+ if (!session_3_leader) {
+ TEST_PCHECK(setsid() >= 0);
+
+ _exit(0);
+ }
+
+ int wstatus;
+ TEST_PCHECK(waitpid(session_3_leader, &wstatus, 0) == session_3_leader);
+ TEST_PCHECK(wstatus == 0);
+
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(session_2_leader, &wstatus, 0),
+ SyscallSucceedsWithValue(session_2_leader));
+ ASSERT_EQ(wstatus, 0);
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc
new file mode 100644
index 000000000..14a4af980
--- /dev/null
+++ b/test/syscalls/linux/pty_root.cc
@@ -0,0 +1,68 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/ioctl.h>
+#include <termios.h>
+
+#include "gtest/gtest.h"
+#include "absl/base/macros.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/pty_util.h"
+
+namespace gvisor {
+namespace testing {
+
+// These tests should be run as root.
+namespace {
+
+TEST(JobControlRootTest, StealTTY) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ // Make this a session leader, which also drops the controlling terminal.
+ // In the gVisor test environment, this test will be run as the session
+ // leader already (as the sentry init process).
+ if (!IsRunningOnGvisor()) {
+ ASSERT_THAT(setsid(), SyscallSucceeds());
+ }
+
+ FileDescriptor master =
+ ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK));
+ FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master));
+
+ // Make slave the controlling terminal.
+ ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds());
+
+ // Fork, join a new session, and try to steal the parent's controlling
+ // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg
+ // of 1.
+ pid_t child = fork();
+ if (!child) {
+ ASSERT_THAT(setsid(), SyscallSucceeds());
+ // We shouldn't be able to steal the terminal with the wrong arg value.
+ TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0));
+ // We should be able to steal it here.
+ TEST_PCHECK(!ioctl(slave.get(), TIOCSCTTY, 1));
+ _exit(0);
+ }
+
+ int wstatus;
+ ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ ASSERT_EQ(wstatus, 0);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc
index 9167ab066..4502e7fb4 100644
--- a/test/syscalls/linux/sendfile.cc
+++ b/test/syscalls/linux/sendfile.cc
@@ -19,9 +19,12 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
@@ -442,6 +445,72 @@ TEST(SendFileTest, SendToNotARegularFile) {
EXPECT_THAT(sendfile(outf.get(), inf.get(), nullptr, 0),
SyscallFailsWithErrno(EINVAL));
}
+
+TEST(SendFileTest, SendPipeWouldBlock) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fool doth think he is wise, but the wise man knows himself to be a "
+ "fool.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Setup the output named pipe.
+ int fds[2];
+ ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(2 * pipe_size);
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST(SendFileTest, SendPipeBlocks) {
+ // Create temp file.
+ constexpr char kData[] =
+ "The fault, dear Brutus, is not in our stars, but in ourselves.";
+ constexpr int kDataSize = sizeof(kData) - 1;
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+
+ // Open the input file as read only.
+ const FileDescriptor inf =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
+
+ // Setup the output named pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill up the pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf(pipe_size);
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, kDataSize),
+ SyscallSucceedsWithValue(kDataSize));
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
new file mode 100644
index 000000000..54c598627
--- /dev/null
+++ b/test/syscalls/linux/signalfd.cc
@@ -0,0 +1,333 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <poll.h>
+#include <signal.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/signalfd.h>
+#include <unistd.h>
+
+#include <functional>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/mutex.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+#include "test/util/signal_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+using ::testing::KilledBySignal;
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kSigno = SIGUSR1;
+constexpr int kSignoAlt = SIGUSR2;
+
+// Returns a new signalfd.
+inline PosixErrorOr<FileDescriptor> NewSignalFD(sigset_t* mask, int flags = 0) {
+ int fd = signalfd(-1, mask, flags);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "signalfd");
+ }
+ return FileDescriptor(fd);
+}
+
+TEST(Signalfd, Basic) {
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Deliver the blocked signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // We should now read the signal.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+}
+
+TEST(Signalfd, MaskWorks) {
+ // Create two signalfds with different masks.
+ sigset_t mask1, mask2;
+ sigemptyset(&mask1);
+ sigemptyset(&mask2);
+ sigaddset(&mask1, kSigno);
+ sigaddset(&mask2, kSignoAlt);
+ FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask1, 0));
+ FileDescriptor fd2 = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask2, 0));
+
+ // Deliver the two signals.
+ const auto scoped_sigmask1 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ const auto scoped_sigmask2 =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSignoAlt));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSignoAlt), SyscallSucceeds());
+
+ // We should see the signals on the appropriate signalfds.
+ //
+ // We read in the opposite order as the signals deliver above, to ensure that
+ // we don't happen to read the correct signal from the correct signalfd.
+ struct signalfd_siginfo rbuf1, rbuf2;
+ ASSERT_THAT(read(fd2.get(), &rbuf2, sizeof(rbuf2)),
+ SyscallSucceedsWithValue(sizeof(rbuf2)));
+ EXPECT_EQ(rbuf2.ssi_signo, kSignoAlt);
+ ASSERT_THAT(read(fd1.get(), &rbuf1, sizeof(rbuf1)),
+ SyscallSucceedsWithValue(sizeof(rbuf1)));
+ EXPECT_EQ(rbuf1.ssi_signo, kSigno);
+}
+
+TEST(Signalfd, Cloexec) {
+ // Exec tests confirm that O_CLOEXEC has the intended effect. We just create a
+ // signalfd with the appropriate flag here and assert that the FD has it set.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+ EXPECT_THAT(fcntl(fd.get(), F_GETFD), SyscallSucceedsWithValue(FD_CLOEXEC));
+}
+
+TEST(Signalfd, Blocking) {
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared tid variable.
+ absl::Mutex mu;
+ bool has_tid;
+ pid_t tid;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Copy the tid and notify the caller.
+ {
+ absl::MutexLock ml(&mu);
+ tid = gettid();
+ has_tid = true;
+ }
+
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+ });
+
+ // Wait until blocked.
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&has_tid));
+
+ // Deliver the signal to either the waiting thread, or
+ // to this thread. N.B. this is a bug in the core gVisor
+ // behavior for signalfd, and needs to be fixed.
+ //
+ // See gvisor.dev/issue/139.
+ if (IsRunningOnGvisor()) {
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+ } else {
+ ASSERT_THAT(tgkill(getpid(), tid, kSigno), SyscallSucceeds());
+ }
+
+ // Ensure that it was received.
+ t.Join();
+}
+
+TEST(Signalfd, ThreadGroup) {
+ // Create the signalfd in blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Shared variable.
+ absl::Mutex mu;
+ bool first = false;
+ bool second = false;
+
+ // Start a thread reading.
+ ScopedThread t([&] {
+ // Read the signal from the signalfd.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Wait for the other thread.
+ absl::MutexLock ml(&mu);
+ first = true;
+ mu.Await(absl::Condition(&second));
+ });
+
+ // Deliver the signal to the threadgroup.
+ ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+
+ // Wait for the first thread to process.
+ {
+ absl::MutexLock ml(&mu);
+ mu.Await(absl::Condition(&first));
+ }
+
+ // Deliver to the thread group again (other thread still exists).
+ ASSERT_THAT(kill(getpid(), kSigno), SyscallSucceeds());
+
+ // Ensure that we can also receive it.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Mark the test as done.
+ {
+ absl::MutexLock ml(&mu);
+ second = true;
+ }
+
+ // The other thread should be joinable.
+ t.Join();
+}
+
+TEST(Signalfd, Nonblock) {
+ // Create the signalfd in non-blocking mode.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+ // We should return if we attempt to read.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Block and deliver the signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // Ensure that a read actually works.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+
+ // Should block again.
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+}
+
+TEST(Signalfd, SetMask) {
+ // Create the signalfd matching nothing.
+ sigset_t mask;
+ sigemptyset(&mask);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_NONBLOCK));
+
+ // Block and deliver a signal.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ ASSERT_THAT(tgkill(getpid(), gettid(), kSigno), SyscallSucceeds());
+
+ // We should have nothing.
+ struct signalfd_siginfo rbuf;
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallFailsWithErrno(EWOULDBLOCK));
+
+ // Change the signal mask.
+ sigaddset(&mask, kSigno);
+ ASSERT_THAT(signalfd(fd.get(), &mask, 0), SyscallSucceeds());
+
+ // We should now have the signal.
+ ASSERT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+ EXPECT_EQ(rbuf.ssi_signo, kSigno);
+}
+
+TEST(Signalfd, Poll) {
+ // Create the signalfd.
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, kSigno);
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, 0));
+
+ // Block the signal, and start a thread to deliver it.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, kSigno));
+ pid_t orig_tid = gettid();
+ ScopedThread t([&] {
+ absl::SleepFor(absl::Seconds(5));
+ ASSERT_THAT(tgkill(getpid(), orig_tid, kSigno), SyscallSucceeds());
+ });
+
+ // Start polling for the signal. We expect that it is not available at the
+ // outset, but then becomes available when the signal is sent. We give a
+ // timeout of 10000ms (or the delay above + 5 seconds of additional grace
+ // time).
+ struct pollfd poll_fd = {fd.get(), POLLIN, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
+ SyscallSucceedsWithValue(1));
+
+ // Actually read the signal to prevent delivery.
+ struct signalfd_siginfo rbuf;
+ EXPECT_THAT(read(fd.get(), &rbuf, sizeof(rbuf)),
+ SyscallSucceedsWithValue(sizeof(rbuf)));
+}
+
+TEST(Signalfd, KillStillKills) {
+ sigset_t mask;
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGKILL);
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSignalFD(&mask, SFD_CLOEXEC));
+
+ // Just because there is a signalfd, we shouldn't see any change in behavior
+ // for unblockable signals. It's easier to test this with SIGKILL.
+ const auto scoped_sigmask =
+ ASSERT_NO_ERRNO_AND_VALUE(ScopedSignalMask(SIG_BLOCK, SIGKILL));
+ EXPECT_EXIT(tgkill(getpid(), gettid(), SIGKILL), KilledBySignal(SIGKILL), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
+
+int main(int argc, char** argv) {
+ // These tests depend on delivering signals. Block them up front so that all
+ // other threads created by TestInit will also have them blocked, and they
+ // will not interface with the rest of the test.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, gvisor::testing::kSigno);
+ sigaddset(&set, gvisor::testing::kSignoAlt);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+
+ gvisor::testing::TestInit(&argc, &argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc
index 9c7210e17..7db57d968 100644
--- a/test/syscalls/linux/sigstop.cc
+++ b/test/syscalls/linux/sigstop.cc
@@ -17,6 +17,7 @@
#include <sys/select.h>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/multiprocess_util.h"
@@ -24,8 +25,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(sigstop_test_child, false,
- "If true, run the SigstopTest child workload.");
+ABSL_FLAG(bool, sigstop_test_child, false,
+ "If true, run the SigstopTest child workload.");
namespace gvisor {
namespace testing {
@@ -141,7 +142,7 @@ void RunChild() {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_sigstop_test_child) {
+ if (absl::GetFlag(FLAGS_sigstop_test_child)) {
gvisor::testing::RunChild();
return 1;
}
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index a43cf9bce..bfa7943b1 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -117,7 +117,7 @@ TEST_P(TCPSocketPairTest, RSTCausesPollHUP) {
struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0};
ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs),
SyscallSucceedsWithValue(1));
- ASSERT_NE(poll_fd.revents & (POLLHUP | POLLIN), 0);
+ ASSERT_NE(poll_fd3.revents & POLLHUP, 0);
}
// This test validates that even if a RST is sent the other end will not
diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc
index e25f264f6..85232cb1f 100644
--- a/test/syscalls/linux/splice.cc
+++ b/test/syscalls/linux/splice.cc
@@ -14,12 +14,16 @@
#include <fcntl.h>
#include <sys/eventfd.h>
+#include <sys/resource.h>
#include <sys/sendfile.h>
+#include <sys/time.h>
#include <unistd.h>
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -36,23 +40,23 @@ TEST(SpliceTest, TwoRegularFiles) {
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
// Open the input file as read only.
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));
// Open the output file as write only.
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));
// Verify that it is rejected as expected; regardless of offsets.
loff_t in_offset = 0;
loff_t out_offset = 0;
- EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), &out_offset, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), &out_offset, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), &out_offset, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), &out_offset, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), &in_offset, outf.get(), nullptr, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_offset, out_fd.get(), nullptr, 1, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(splice(inf.get(), nullptr, outf.get(), nullptr, 1, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, out_fd.get(), nullptr, 1, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -75,8 +79,6 @@ TEST(SpliceTest, SamePipe) {
}
TEST(TeeTest, SamePipe) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create a new pipe.
int fds[2];
ASSERT_THAT(pipe(fds), SyscallSucceeds());
@@ -95,11 +97,9 @@ TEST(TeeTest, SamePipe) {
}
TEST(TeeTest, RegularFile) {
- SKIP_IF(IsRunningOnGvisor());
-
// Open some file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Create a new pipe.
@@ -109,9 +109,9 @@ TEST(TeeTest, RegularFile) {
const FileDescriptor wfd(fds[1]);
// Attempt to tee from the file.
- EXPECT_THAT(tee(inf.get(), wfd.get(), kPageSize, 0),
+ EXPECT_THAT(tee(in_fd.get(), wfd.get(), kPageSize, 0),
SyscallFailsWithErrno(EINVAL));
- EXPECT_THAT(tee(rfd.get(), inf.get(), kPageSize, 0),
+ EXPECT_THAT(tee(rfd.get(), in_fd.get(), kPageSize, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -142,7 +142,7 @@ TEST(SpliceTest, FromEventFD) {
constexpr uint64_t kEventFDValue = 1;
int efd;
ASSERT_THAT(efd = eventfd(kEventFDValue, 0), SyscallSucceeds());
- const FileDescriptor inf(efd);
+ const FileDescriptor in_fd(efd);
// Create a new pipe.
int fds[2];
@@ -152,7 +152,7 @@ TEST(SpliceTest, FromEventFD) {
// Splice 8-byte eventfd value to pipe.
constexpr int kEventFDSize = 8;
- EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kEventFDSize, 0),
SyscallSucceedsWithValue(kEventFDSize));
// Contents should be equal.
@@ -166,7 +166,7 @@ TEST(SpliceTest, FromEventFD) {
TEST(SpliceTest, FromEventFDOffset) {
int efd;
ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor inf(efd);
+ const FileDescriptor in_fd(efd);
// Create a new pipe.
int fds[2];
@@ -179,7 +179,7 @@ TEST(SpliceTest, FromEventFDOffset) {
// This is not allowed because eventfd doesn't support pread.
constexpr int kEventFDSize = 8;
loff_t in_off = 0;
- EXPECT_THAT(splice(inf.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), &in_off, wfd.get(), nullptr, kEventFDSize, 0),
SyscallFailsWithErrno(EINVAL));
}
@@ -200,28 +200,29 @@ TEST(SpliceTest, ToEventFDOffset) {
int efd;
ASSERT_THAT(efd = eventfd(0, 0), SyscallSucceeds());
- const FileDescriptor outf(efd);
+ const FileDescriptor out_fd(efd);
// Attempt to splice 8-byte eventfd value to pipe with offset.
//
// This is not allowed because eventfd doesn't support pwrite.
loff_t out_off = 0;
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_off, kEventFDSize, 0),
- SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_off, kEventFDSize, 0),
+ SyscallFailsWithErrno(EINVAL));
}
TEST(SpliceTest, ToPipe) {
// Open the input file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Fill with some random data.
std::vector<char> buf(kPageSize);
RandomizeBuffer(buf.data(), buf.size());
- ASSERT_THAT(write(inf.get(), buf.data(), buf.size()),
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(kPageSize));
- ASSERT_THAT(lseek(inf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(in_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
// Create a new pipe.
int fds[2];
@@ -230,7 +231,7 @@ TEST(SpliceTest, ToPipe) {
const FileDescriptor wfd(fds[1]);
// Splice to the pipe.
- EXPECT_THAT(splice(inf.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
+ EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, kPageSize, 0),
SyscallSucceedsWithValue(kPageSize));
// Contents should be equal.
@@ -243,13 +244,13 @@ TEST(SpliceTest, ToPipe) {
TEST(SpliceTest, ToPipeOffset) {
// Open the input file.
const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor inf =
+ const FileDescriptor in_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
// Fill with some random data.
std::vector<char> buf(kPageSize);
RandomizeBuffer(buf.data(), buf.size());
- ASSERT_THAT(write(inf.get(), buf.data(), buf.size()),
+ ASSERT_THAT(write(in_fd.get(), buf.data(), buf.size()),
SyscallSucceedsWithValue(kPageSize));
// Create a new pipe.
@@ -261,7 +262,7 @@ TEST(SpliceTest, ToPipeOffset) {
// Splice to the pipe.
loff_t in_offset = kPageSize / 2;
EXPECT_THAT(
- splice(inf.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0),
+ splice(in_fd.get(), &in_offset, wfd.get(), nullptr, kPageSize / 2, 0),
SyscallSucceedsWithValue(kPageSize / 2));
// Contents should be equal to only the second part.
@@ -286,22 +287,22 @@ TEST(SpliceTest, FromPipe) {
// Open the input file.
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
// Splice to the output file.
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), nullptr, kPageSize, 0),
+ EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, kPageSize, 0),
SyscallSucceedsWithValue(kPageSize));
// The offset of the output should be equal to kPageSize. We assert that and
// reset to zero so that we can read the contents and ensure they match.
- EXPECT_THAT(lseek(outf.get(), 0, SEEK_CUR),
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_CUR),
SyscallSucceedsWithValue(kPageSize));
- ASSERT_THAT(lseek(outf.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
+ ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0));
// Contents should be equal.
std::vector<char> rbuf(kPageSize);
- ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()),
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
SyscallSucceedsWithValue(kPageSize));
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
}
@@ -321,18 +322,19 @@ TEST(SpliceTest, FromPipeOffset) {
// Open the input file.
const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
- const FileDescriptor outf =
+ const FileDescriptor out_fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
// Splice to the output file.
loff_t out_offset = kPageSize / 2;
- EXPECT_THAT(splice(rfd.get(), nullptr, outf.get(), &out_offset, kPageSize, 0),
- SyscallSucceedsWithValue(kPageSize));
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), &out_offset, kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
// Content should reflect the splice. We write to a specific offset in the
// file, so the internals should now be allocated sparsely.
std::vector<char> rbuf(kPageSize);
- ASSERT_THAT(read(outf.get(), rbuf.data(), rbuf.size()),
+ ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()),
SyscallSucceedsWithValue(kPageSize));
std::vector<char> zbuf(kPageSize / 2);
memset(zbuf.data(), 0, zbuf.size());
@@ -404,8 +406,6 @@ TEST(SpliceTest, Blocking) {
}
TEST(TeeTest, Blocking) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create two new pipes.
int first[2], second[2];
ASSERT_THAT(pipe(first), SyscallSucceeds());
@@ -440,6 +440,49 @@ TEST(TeeTest, Blocking) {
EXPECT_EQ(memcmp(rbuf.data(), buf.data(), kPageSize), 0);
}
+TEST(TeeTest, BlockingWrite) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> buf1(kPageSize);
+ RandomizeBuffer(buf1.data(), buf1.size());
+ ASSERT_THAT(write(wfd1.get(), buf1.data(), buf1.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Fill up the write pipe's buffer.
+ int pipe_size = -1;
+ ASSERT_THAT(pipe_size = fcntl(wfd2.get(), F_GETPIPE_SZ), SyscallSucceeds());
+ std::vector<char> buf2(pipe_size);
+ ASSERT_THAT(write(wfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+
+ ScopedThread t([&]() {
+ absl::SleepFor(absl::Milliseconds(100));
+ ASSERT_THAT(read(rfd2.get(), buf2.data(), buf2.size()),
+ SyscallSucceedsWithValue(pipe_size));
+ });
+
+ // Attempt a tee immediately; it should block.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), kPageSize, 0),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Thread should be joinable.
+ t.Join();
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf1.data(), kPageSize), 0);
+}
+
TEST(SpliceTest, NonBlocking) {
// Create two new pipes.
int first[2], second[2];
@@ -457,8 +500,6 @@ TEST(SpliceTest, NonBlocking) {
}
TEST(TeeTest, NonBlocking) {
- SKIP_IF(IsRunningOnGvisor());
-
// Create two new pipes.
int first[2], second[2];
ASSERT_THAT(pipe(first), SyscallSucceeds());
@@ -473,6 +514,79 @@ TEST(TeeTest, NonBlocking) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST(TeeTest, MultiPage) {
+ // Create two new pipes.
+ int first[2], second[2];
+ ASSERT_THAT(pipe(first), SyscallSucceeds());
+ const FileDescriptor rfd1(first[0]);
+ const FileDescriptor wfd1(first[1]);
+ ASSERT_THAT(pipe(second), SyscallSucceeds());
+ const FileDescriptor rfd2(second[0]);
+ const FileDescriptor wfd2(second[1]);
+
+ // Make some data available to be read.
+ std::vector<char> wbuf(8 * kPageSize);
+ RandomizeBuffer(wbuf.data(), wbuf.size());
+ ASSERT_THAT(write(wfd1.get(), wbuf.data(), wbuf.size()),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Attempt a tee immediately; it should complete.
+ EXPECT_THAT(tee(rfd1.get(), wfd2.get(), wbuf.size(), 0),
+ SyscallSucceedsWithValue(wbuf.size()));
+
+ // Content should reflect the tee.
+ std::vector<char> rbuf(wbuf.size());
+ ASSERT_THAT(read(rfd2.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+ ASSERT_THAT(read(rfd1.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(rbuf.size()));
+ EXPECT_EQ(memcmp(rbuf.data(), wbuf.data(), rbuf.size()), 0);
+}
+
+TEST(SpliceTest, FromPipeMaxFileSize) {
+ // Create a new pipe.
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ const FileDescriptor rfd(fds[0]);
+ const FileDescriptor wfd(fds[1]);
+
+ // Fill with some random data.
+ std::vector<char> buf(kPageSize);
+ RandomizeBuffer(buf.data(), buf.size());
+ ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+
+ // Open the input file.
+ const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor out_fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR));
+
+ EXPECT_THAT(ftruncate(out_fd.get(), 13 << 20), SyscallSucceeds());
+ EXPECT_THAT(lseek(out_fd.get(), 0, SEEK_END),
+ SyscallSucceedsWithValue(13 << 20));
+
+ // Set our file size limit.
+ sigset_t set;
+ sigemptyset(&set);
+ sigaddset(&set, SIGXFSZ);
+ TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
+ rlimit rlim = {};
+ rlim.rlim_cur = rlim.rlim_max = (13 << 20);
+ EXPECT_THAT(setrlimit(RLIMIT_FSIZE, &rlim), SyscallSucceeds());
+
+ // Splice to the output file.
+ EXPECT_THAT(
+ splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3 * kPageSize, 0),
+ SyscallFailsWithErrno(EFBIG));
+
+ // Contents should be equal.
+ std::vector<char> rbuf(kPageSize);
+ ASSERT_THAT(read(rfd.get(), rbuf.data(), rbuf.size()),
+ SyscallSucceedsWithValue(kPageSize));
+ EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/sticky.cc b/test/syscalls/linux/sticky.cc
index 59fb5dfe6..7e73325bf 100644
--- a/test/syscalls/linux/sticky.cc
+++ b/test/syscalls/linux/sticky.cc
@@ -17,9 +17,11 @@
#include <sys/prctl.h>
#include <sys/types.h>
#include <unistd.h>
+
#include <vector>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "test/util/capability_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
@@ -27,8 +29,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid, 65534, "first scratch UID");
-DEFINE_int32(scratch_gid, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_uid, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_gid, 65534, "first scratch GID");
namespace gvisor {
namespace testing {
@@ -52,10 +54,12 @@ TEST(StickyTest, StickyBitPermDenied) {
}
// Change EUID and EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
EXPECT_THAT(rmdir(path.c_str()), SyscallFailsWithErrno(EPERM));
});
@@ -78,8 +82,9 @@ TEST(StickyTest, StickyBitSameUID) {
}
// Change EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
// We still have the same EUID.
EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
@@ -101,10 +106,12 @@ TEST(StickyTest, StickyBitCapFOWNER) {
EXPECT_THAT(prctl(PR_SET_KEEPCAPS, 1, 0, 0, 0), SyscallSucceeds());
// Change EUID and EGID.
- EXPECT_THAT(syscall(SYS_setresgid, -1, FLAGS_scratch_gid, -1),
- SyscallSucceeds());
- EXPECT_THAT(syscall(SYS_setresuid, -1, FLAGS_scratch_uid, -1),
- SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresgid, -1, absl::GetFlag(FLAGS_scratch_gid), -1),
+ SyscallSucceeds());
+ EXPECT_THAT(
+ syscall(SYS_setresuid, -1, absl::GetFlag(FLAGS_scratch_uid), -1),
+ SyscallSucceeds());
EXPECT_NO_ERRNO(SetCapability(CAP_FOWNER, true));
EXPECT_THAT(rmdir(path.c_str()), SyscallSucceeds());
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index fd42e81e1..3db18d7ac 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -23,6 +23,7 @@
#include <atomic>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "test/util/cleanup.h"
@@ -33,8 +34,8 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_bool(timers_test_sleep, false,
- "If true, sleep forever instead of running tests.");
+ABSL_FLAG(bool, timers_test_sleep, false,
+ "If true, sleep forever instead of running tests.");
using ::testing::_;
using ::testing::AnyOf;
@@ -635,7 +636,7 @@ TEST(IntervalTimerTest, IgnoredSignalCountsAsOverrun) {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_timers_test_sleep) {
+ if (absl::GetFlag(FLAGS_timers_test_sleep)) {
while (true) {
absl::SleepFor(absl::Seconds(10));
}
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
index bf1ca8679..d48453a93 100644
--- a/test/syscalls/linux/uidgid.cc
+++ b/test/syscalls/linux/uidgid.cc
@@ -18,6 +18,7 @@
#include <unistd.h>
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "test/util/capability_util.h"
@@ -25,10 +26,10 @@
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
-DEFINE_int32(scratch_uid1, 65534, "first scratch UID");
-DEFINE_int32(scratch_uid2, 65533, "second scratch UID");
-DEFINE_int32(scratch_gid1, 65534, "first scratch GID");
-DEFINE_int32(scratch_gid2, 65533, "second scratch GID");
+ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
+ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
+ABSL_FLAG(int32_t, scratch_gid1, 65534, "first scratch GID");
+ABSL_FLAG(int32_t, scratch_gid2, 65533, "second scratch GID");
using ::testing::UnorderedElementsAreArray;
@@ -146,7 +147,7 @@ TEST(UidGidRootTest, Setuid) {
// real UID.
EXPECT_THAT(syscall(SYS_setuid, -1), SyscallFailsWithErrno(EINVAL));
- const uid_t uid = FLAGS_scratch_uid1;
+ const uid_t uid = absl::GetFlag(FLAGS_scratch_uid1);
EXPECT_THAT(syscall(SYS_setuid, uid), SyscallSucceeds());
// "If the effective UID of the caller is root (more precisely: if the
// caller has the CAP_SETUID capability), the real UID and saved set-user-ID
@@ -160,7 +161,7 @@ TEST(UidGidRootTest, Setgid) {
EXPECT_THAT(setgid(-1), SyscallFailsWithErrno(EINVAL));
- const gid_t gid = FLAGS_scratch_gid1;
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
ASSERT_THAT(setgid(gid), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(gid, gid, gid));
}
@@ -168,7 +169,7 @@ TEST(UidGidRootTest, Setgid) {
TEST(UidGidRootTest, SetgidNotFromThreadGroupLeader) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
- const gid_t gid = FLAGS_scratch_gid1;
+ const gid_t gid = absl::GetFlag(FLAGS_scratch_gid1);
// NOTE(b/64676707): Do setgid in a separate thread so that we can test if
// info.si_pid is set correctly.
ScopedThread([gid] { ASSERT_THAT(setgid(gid), SyscallSucceeds()); });
@@ -189,8 +190,8 @@ TEST(UidGidRootTest, Setreuid) {
// cannot be opened by the `uid` set below after the test. After calling
// setuid(non-zero-UID), there is no way to get root privileges back.
ScopedThread([&] {
- const uid_t ruid = FLAGS_scratch_uid1;
- const uid_t euid = FLAGS_scratch_uid2;
+ const uid_t ruid = absl::GetFlag(FLAGS_scratch_uid1);
+ const uid_t euid = absl::GetFlag(FLAGS_scratch_uid2);
// Use syscall instead of glibc setuid wrapper because we want this setuid
// call to only apply to this task. posix threads, however, require that all
@@ -211,8 +212,8 @@ TEST(UidGidRootTest, Setregid) {
EXPECT_THAT(setregid(-1, -1), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(0, 0, 0));
- const gid_t rgid = FLAGS_scratch_gid1;
- const gid_t egid = FLAGS_scratch_gid2;
+ const gid_t rgid = absl::GetFlag(FLAGS_scratch_gid1);
+ const gid_t egid = absl::GetFlag(FLAGS_scratch_gid2);
ASSERT_THAT(setregid(rgid, egid), SyscallSucceeds());
EXPECT_NO_ERRNO(CheckGIDs(rgid, egid, egid));
}
diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc
index b6f65e027..2040375c9 100644
--- a/test/syscalls/linux/unlink.cc
+++ b/test/syscalls/linux/unlink.cc
@@ -123,6 +123,8 @@ TEST(UnlinkTest, AtBad) {
SyscallSucceeds());
EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", AT_REMOVEDIR),
SyscallFailsWithErrno(ENOTDIR));
+ EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile/", 0),
+ SyscallFailsWithErrno(ENOTDIR));
ASSERT_THAT(close(fd), SyscallSucceeds());
EXPECT_THAT(unlinkat(dirfd, "UnlinkAtFile", 0), SyscallSucceeds());
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
index f67b06f37..0aaba482d 100644
--- a/test/syscalls/linux/vfork.cc
+++ b/test/syscalls/linux/vfork.cc
@@ -22,14 +22,15 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/flags/flag.h"
#include "absl/time/time.h"
#include "test/util/logging.h"
#include "test/util/multiprocess_util.h"
#include "test/util/test_util.h"
#include "test/util/time_util.h"
-DEFINE_bool(vfork_test_child, false,
- "If true, run the VforkTest child workload.");
+ABSL_FLAG(bool, vfork_test_child, false,
+ "If true, run the VforkTest child workload.");
namespace gvisor {
namespace testing {
@@ -186,7 +187,7 @@ int RunChild() {
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- if (FLAGS_vfork_test_child) {
+ if (absl::GetFlag(FLAGS_vfork_test_child)) {
return gvisor::testing::RunChild();
}
diff --git a/test/syscalls/syscall_test_runner.go b/test/syscalls/syscall_test_runner.go
index e900f8abc..c1e9ce22c 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/syscalls/syscall_test_runner.go
@@ -20,12 +20,10 @@ import (
"flag"
"fmt"
"io/ioutil"
- "math"
"os"
"os/exec"
"os/signal"
"path/filepath"
- "strconv"
"strings"
"syscall"
"testing"
@@ -358,32 +356,14 @@ func main() {
fatalf("ParseTestCases(%q) failed: %v", testBin, err)
}
- // If sharding, then get the subset of tests to run based on the shard index.
- if indexStr, totalStr := os.Getenv("TEST_SHARD_INDEX"), os.Getenv("TEST_TOTAL_SHARDS"); indexStr != "" && totalStr != "" {
- // Parse index and total to ints.
- index, err := strconv.Atoi(indexStr)
- if err != nil {
- fatalf("invalid TEST_SHARD_INDEX %q: %v", indexStr, err)
- }
- total, err := strconv.Atoi(totalStr)
- if err != nil {
- fatalf("invalid TEST_TOTAL_SHARDS %q: %v", totalStr, err)
- }
- // Calculate subslice of tests to run.
- shardSize := int(math.Ceil(float64(len(testCases)) / float64(total)))
- begin := index * shardSize
- // Set end as begin of next subslice.
- end := ((index + 1) * shardSize)
- if begin > len(testCases) {
- // Nothing to run.
- return
- }
- if end > len(testCases) {
- end = len(testCases)
- }
- testCases = testCases[begin:end]
+ // Get subset of tests corresponding to shard.
+ begin, end, err := testutil.TestBoundsForShard(len(testCases))
+ if err != nil {
+ fatalf("TestsForShard() failed: %v", err)
}
+ testCases = testCases[begin:end]
+ // Run the tests.
var tests []testing.InternalTest
for _, tc := range testCases {
// Capture tc.
diff --git a/test/util/BUILD b/test/util/BUILD
index 8afd89d8d..25ed9c944 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -1,3 +1,4 @@
+load("@rules_cc//cc:defs.bzl", "cc_library", "cc_test")
load("//test/syscalls:build_defs.bzl", "select_for_linux")
package(
@@ -190,6 +191,17 @@ cc_test(
)
cc_library(
+ name = "pty_util",
+ testonly = 1,
+ srcs = ["pty_util.cc"],
+ hdrs = ["pty_util.h"],
+ deps = [
+ ":file_descriptor",
+ ":posix_error",
+ ],
+)
+
+cc_library(
name = "signal_util",
testonly = 1,
srcs = ["signal_util.cc"],
@@ -227,8 +239,9 @@ cc_library(
":logging",
":posix_error",
":save_util",
- "@com_github_gflags_gflags//:gflags",
"@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/flags:flag",
+ "@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
diff --git a/test/util/memory_util.h b/test/util/memory_util.h
index 190c469b5..e189b73e8 100644
--- a/test/util/memory_util.h
+++ b/test/util/memory_util.h
@@ -118,6 +118,18 @@ inline PosixErrorOr<Mapping> MmapAnon(size_t length, int prot, int flags) {
return Mmap(nullptr, length, prot, flags | MAP_ANONYMOUS, -1, 0);
}
+// Wrapper for mremap that returns a PosixErrorOr<>, since the return type of
+// void* isn't directly compatible with SyscallSucceeds.
+inline PosixErrorOr<void*> Mremap(void* old_address, size_t old_size,
+ size_t new_size, int flags,
+ void* new_address) {
+ void* rv = mremap(old_address, old_size, new_size, flags, new_address);
+ if (rv == MAP_FAILED) {
+ return PosixError(errno, "mremap failed");
+ }
+ return rv;
+}
+
// Returns true if the page containing addr is mapped.
inline bool IsMapped(uintptr_t addr) {
int const rv = msync(reinterpret_cast<void*>(addr & ~(kPageSize - 1)),
diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc
new file mode 100644
index 000000000..c0fd9a095
--- /dev/null
+++ b/test/util/pty_util.cc
@@ -0,0 +1,45 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/pty_util.h"
+
+#include <sys/ioctl.h>
+#include <termios.h>
+
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) {
+ // Get pty index.
+ int n;
+ int ret = ioctl(master.get(), TIOCGPTN, &n);
+ if (ret < 0) {
+ return PosixError(errno, "ioctl(TIOCGPTN) failed");
+ }
+
+ // Unlock pts.
+ int unlock = 0;
+ ret = ioctl(master.get(), TIOCSPTLCK, &unlock);
+ if (ret < 0) {
+ return PosixError(errno, "ioctl(TIOSPTLCK) failed");
+ }
+
+ return Open(absl::StrCat("/dev/pts/", n), O_RDWR | O_NONBLOCK);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/pty_util.h b/test/util/pty_util.h
new file mode 100644
index 000000000..367b14f15
--- /dev/null
+++ b/test/util/pty_util.h
@@ -0,0 +1,30 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_PTY_UTIL_H_
+#define GVISOR_TEST_UTIL_PTY_UTIL_H_
+
+#include "test/util/file_descriptor.h"
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Opens the slave end of the passed master as R/W and nonblocking.
+PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_PTY_UTIL_H_
diff --git a/test/util/test_util.cc b/test/util/test_util.cc
index e42bba04a..ba0dcf7d0 100644
--- a/test/util/test_util.cc
+++ b/test/util/test_util.cc
@@ -28,6 +28,8 @@
#include <vector>
#include "absl/base/attributes.h"
+#include "absl/flags/flag.h" // IWYU pragma: keep
+#include "absl/flags/parse.h" // IWYU pragma: keep
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
@@ -224,7 +226,7 @@ bool Equivalent(uint64_t current, uint64_t target, double tolerance) {
void TestInit(int* argc, char*** argv) {
::testing::InitGoogleTest(argc, *argv);
- ::gflags::ParseCommandLineFlags(argc, argv, true);
+ ::absl::ParseCommandLine(*argc, *argv);
// Always mask SIGPIPE as it's common and tests aren't expected to handle it.
struct sigaction sa = {};
diff --git a/test/util/test_util.h b/test/util/test_util.h
index cdbe8bfd1..b9d2dc2ba 100644
--- a/test/util/test_util.h
+++ b/test/util/test_util.h
@@ -185,7 +185,6 @@
#include <utility>
#include <vector>
-#include <gflags/gflags.h>
#include "gmock/gmock.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
diff --git a/third_party/gvsync/downgradable_rwmutex_unsafe.go b/third_party/gvsync/downgradable_rwmutex_unsafe.go
index 069939033..1f6007aa1 100644
--- a/third_party/gvsync/downgradable_rwmutex_unsafe.go
+++ b/third_party/gvsync/downgradable_rwmutex_unsafe.go
@@ -57,9 +57,6 @@ func (rw *DowngradableRWMutex) RLock() {
// RUnlock undoes a single RLock call.
func (rw *DowngradableRWMutex) RUnlock() {
if RaceEnabled {
- // TODO(jamieliu): Why does this need to be ReleaseMerge instead of
- // Release? IIUC this establishes Unlock happens-before RUnlock, which
- // seems unnecessary.
RaceReleaseMerge(unsafe.Pointer(&rw.writerSem))
RaceDisable()
}
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index d9e79401d..ddb9b6e7b 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -59,7 +59,11 @@ git checkout -b go "${go_branch}"
# Start working on a merge commit that combines the previous history with the
# current history. Note that we don't actually want any changes yet.
-git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
+#
+# N.B. The git behavior changed at some point and the relevant flag was added
+# to allow for override, so try the only behavior first then pass the flag.
+git merge --no-commit --strategy ours ${head} || \
+ git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
# Sync the entire gopath_dir and go.mod.
rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" .
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
new file mode 100644
index 000000000..c862b277c
--- /dev/null
+++ b/tools/go_marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "go_marshal",
+ srcs = ["main.go"],
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//tools/go_marshal/gomarshal",
+ ],
+)
diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md
new file mode 100644
index 000000000..481575bd3
--- /dev/null
+++ b/tools/go_marshal/README.md
@@ -0,0 +1,164 @@
+This package implements the go_marshal utility.
+
+# Overview
+
+`go_marshal` is a code generation utility similar to `go_stateify` for
+automatically generating code to marshal go data structures to memory.
+
+`go_marshal` attempts to improve on `binary.Write` and the sentry's
+`binary.Marshal` by moving the go runtime reflection necessary to marshal a
+struct to compile-time.
+
+`go_marshal` automatically generates implementations for `abi.Marshallable` and
+`safemem.{Reader,Writer}`. Call-sites for serialization (typically syscall
+implementations) can directly invoke `safemem.Reader.ReadToBlocks` and
+`safemem.Writer.WriteFromBlocks`. Data structures that require custom
+serialization will have manual implementations for these interfaces.
+
+Data structures can be flagged for code generation by adding a struct-level
+comment `// +marshal`.
+
+# Usage
+
+See `defs.bzl`: two new rules are provided, `go_marshal` and `go_library`.
+
+The recommended way to generate a go library with marshalling is to use the
+`go_library` with mostly identical configuration as the native go_library rule.
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+```
+
+Under the hood, the `go_marshal` rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (note that the above is the preferred method):
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ deps = [
+ "<PKGPATH>/gvisor/pkg/abi",
+ "<PKGPATH>/gvisor/pkg/sentry/safemem/safemem",
+ "<PKGPATH>/gvisor/pkg/sentry/usermem/usermem",
+ ],
+)
+```
+
+As part of the interface generation, `go_marshal` also generates some tests for
+sanity checking the struct definitions for potential alignment issues, and a
+simple round-trip test through Marshal/Unmarshal to verify the implementation.
+These tests use reflection to verify properties of the ABI struct, and should be
+considered part of the generated interfaces (but are too expensive to execute at
+runtime). Ensure these tests run at some point.
+
+```
+$ cat BUILD
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+$ blaze build :foo
+$ blaze query ...
+<path-to-dir>:foo_abi_autogen
+<path-to-dir>:foo_abi_autogen_test
+$ blaze test :foo_abi_autogen_test
+<test-output>
+```
+
+# Restrictions
+
+Not all valid go type definitions can be used with `go_marshal`. `go_marshal` is
+intended for ABI structs, which have these additional restrictions:
+
+- At the moment, `go_marshal` only supports struct declarations.
+
+- Structs are marshalled as packed types. This means no implicit padding is
+ inserted between fields shorter than the platform register size. For
+ alignment, manually insert padding fields.
+
+- Structs used with `go_marshal` must have a compile-time static size. This
+ means no dynamically sizes fields like slices or strings. Use statically
+ sized array (byte arrays for strings) instead.
+
+- No pointers, channel, map or function pointer fields, and no fields that are
+ arrays of these types. These don't make sense in an ABI data structure.
+
+- We could support opaque pointers as `uintptr`, but this is currently not
+ implemented. Implementing this would require handling the architecture
+ dependent native pointer size.
+
+- Fields must either be a primitive integer type (`byte`,
+ `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable.
+
+- `int` and `uint` fields are not allowed. Use an explicitly-sized numeric
+ type.
+
+- `float*` fields are currently not supported, but could be if necessary.
+
+# Appendix
+
+## Working with Non-Packed Structs
+
+ABI structs must generally be packed types, meaning they should have no implicit
+padding between short fields. However, if a field is tagged
+`marshal:"unaligned"`, `go_marshal` will fall back to a safer but slower
+mechanism to deal with potentially unaligned fields.
+
+Note that the non-packed property is inheritted by any other struct that embeds
+this struct, since the `go_marshal` tool currently can't reason about alignments
+for embedded structs that are not aligned.
+
+Because of this, it's generally best to avoid using `marshal:"unaligned"` and
+insert explicit padding fields instead.
+
+## Debugging go_marshal
+
+To enable debugging output from the go marshal tool, pass the `-debug` flag to
+the tool. When using the build rules from above, add a `debug = True` field to
+the build rule like this:
+
+```
+load("<PKGPATH>/gvisor/tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+ debug = True,
+)
+```
+
+## Modifying the `go_marshal` Tool
+
+The following are some guidelines for modifying the `go_marshal` tool:
+
+- The `go_marshal` tool currently does a single pass over all types requesting
+ code generation, in arbitrary order. This means the generated code can't
+ directly obtain information about embedded marshallable types at
+ compile-time. One way to work around this restriction is to add a new
+ Marshallable interface method providing this piece of information, and
+ calling it from the generated code. Use this sparingly, as we want to rely
+ on compile-time information as much as possible for performance.
+
+- No runtime reflection in the code generated for the marshallable interface.
+ The entire point of the tool is to avoid runtime reflection. The generated
+ tests may use reflection.
diff --git a/tools/go_marshal/analysis/BUILD b/tools/go_marshal/analysis/BUILD
new file mode 100644
index 000000000..c859ced77
--- /dev/null
+++ b/tools/go_marshal/analysis/BUILD
@@ -0,0 +1,13 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "analysis",
+ testonly = 1,
+ srcs = ["analysis_unsafe.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/analysis",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/analysis/analysis_unsafe.go b/tools/go_marshal/analysis/analysis_unsafe.go
new file mode 100644
index 000000000..9a9a4f298
--- /dev/null
+++ b/tools/go_marshal/analysis/analysis_unsafe.go
@@ -0,0 +1,175 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package analysis implements common functionality used by generated
+// go_marshal tests.
+package analysis
+
+// All functions in this package are unsafe and are not intended for general
+// consumption. They contain sharp edge cases and the caller is responsible for
+// ensuring none of them are hit. Callers must be carefully to pass in only sane
+// arguments. Failure to do so may cause panics at best and arbitrary memory
+// corruption at worst.
+//
+// Never use outside of tests.
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+// RandomizeValue assigns random value(s) to an abitrary type. This is intended
+// for used with ABI structs from go_marshal, meaning the typical restrictions
+// apply (fixed-size types, no pointers, maps, channels, etc), and should only
+// be used on zeroed values to avoid overwriting pointers to active go objects.
+//
+// Internally, we populate the type with random data by doing an unsafe cast to
+// access the underlying memory of the type and filling it as if it were a byte
+// slice. This almost gets us what we want, but padding fields named "_" are
+// normally not accessible, so we walk the type and recursively zero all "_"
+// fields.
+//
+// Precondition: x must be a pointer. x must not contain any valid
+// pointers to active go objects (pointer fields aren't allowed in ABI
+// structs anyways), or we'd be violating the go runtime contract and
+// the GC may malfunction.
+func RandomizeValue(x interface{}) {
+ v := reflect.Indirect(reflect.ValueOf(x))
+ if !v.CanSet() {
+ panic("RandomizeType() called with an unaddressable value. You probably need to pass a pointer to the argument")
+ }
+
+ // Cast the underlying memory for the type into a byte slice.
+ var b []byte
+ hdr := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ // Note: v.UnsafeAddr panics if x is passed by value. x should be a pointer.
+ hdr.Data = v.UnsafeAddr()
+ hdr.Len = int(v.Type().Size())
+ hdr.Cap = hdr.Len
+
+ // Fill the byte slice with random data, which in effect fills the type with
+ // random values.
+ n, err := rand.Read(b)
+ if err != nil || n != len(b) {
+ panic("unreachable")
+ }
+
+ // Normally, padding fields are not accessible, so zero them out.
+ reflectZeroPaddingFields(v.Type(), b, false)
+}
+
+// reflectZeroPaddingFields assigns zero values to padding fields for the value
+// of type r, represented by the memory in data. Padding fields are defined as
+// fields with the name "_". If zero is true, the immediate value itself is
+// zeroed. In addition, the type is recursively scanned for padding fields in
+// inner types.
+//
+// This is used for zeroing padding fields after calling RandomizeValue.
+func reflectZeroPaddingFields(r reflect.Type, data []byte, zero bool) {
+ if zero {
+ for i, _ := range data {
+ data[i] = 0
+ }
+ }
+ switch r.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // These types are explicitly allowed in an ABI type, but we don't need
+ // to recurse further as they're scalar types.
+ case reflect.Struct:
+ for i, numFields := 0, r.NumField(); i < numFields; i++ {
+ f := r.Field(i)
+ off := f.Offset
+ len := f.Type.Size()
+ window := data[off : off+len]
+ reflectZeroPaddingFields(f.Type, window, f.Name == "_")
+ }
+ case reflect.Array:
+ eLen := int(r.Elem().Size())
+ if int(r.Size()) != eLen*r.Len() {
+ panic("Array has unexpected size?")
+ }
+ for i, n := 0, r.Len(); i < n; i++ {
+ reflectZeroPaddingFields(r.Elem(), data[i*eLen:(i+1)*eLen], false)
+ }
+ default:
+ panic(fmt.Sprintf("Type %v not allowed in ABI struct", r.Kind()))
+
+ }
+}
+
+// AlignmentCheck ensures the definition of the type represented by typ doesn't
+// cause the go compiler to emit implicit padding between elements of the type
+// (i.e. fields in a struct).
+//
+// AlignmentCheck doesn't explicitly recurse for embedded structs because any
+// struct present in an ABI struct must also be Marshallable, and therefore
+// they're aligned by definition (or their alignment check would have failed).
+func AlignmentCheck(t *testing.T, typ reflect.Type) (ok bool, delta uint64) {
+ switch typ.Kind() {
+ case reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64:
+ // Primitive types are always considered well aligned. Primitive types
+ // that are fields in structs are checked independently, this branch
+ // exists to handle recursive calls to alignmentCheck.
+ case reflect.Struct:
+ xOff := 0
+ nextXOff := 0
+ skipNext := false
+ for i, numFields := 0, typ.NumField(); i < numFields; i++ {
+ xOff = nextXOff
+ f := typ.Field(i)
+ fmt.Printf("Checking alignment of %s.%s @ %d [+%d]...\n", typ.Name(), f.Name, f.Offset, f.Type.Size())
+ nextXOff = int(f.Offset + f.Type.Size())
+
+ if f.Name == "_" {
+ // Padding fields need not be aligned.
+ fmt.Printf("Padding field of type %v\n", f.Type)
+ continue
+ }
+
+ if tag, ok := f.Tag.Lookup("marshal"); ok && tag == "unaligned" {
+ skipNext = true
+ continue
+ }
+
+ if skipNext {
+ skipNext = false
+ fmt.Printf("Skipping alignment check for field %s.%s explicitly marked as unaligned.\n", typ.Name(), f.Name)
+ continue
+ }
+
+ if xOff != int(f.Offset) {
+ implicitPad := int(f.Offset) - xOff
+ t.Fatalf("Suspect offset for field %s.%s, detected an implicit %d byte padding from offset %d to %d; either add %d bytes of explicit padding before this field or tag it as `marshal:\"unaligned\"`.", typ.Name(), f.Name, implicitPad, xOff, f.Offset, implicitPad)
+ }
+ }
+
+ // Ensure structs end on a byte explicitly defined by the type.
+ if typ.NumField() > 0 && nextXOff != int(typ.Size()) {
+ implicitPad := int(typ.Size()) - nextXOff
+ f := typ.Field(typ.NumField() - 1) // Final field
+ t.Fatalf("Suspect offset for field %s.%s at the end of %s, detected an implicit %d byte padding from offset %d to %d at the end of the struct; either add %d bytes of explict padding at end of the struct or tag the final field %s as `marshal:\"unaligned\"`.",
+ typ.Name(), f.Name, typ.Name(), implicitPad, nextXOff, typ.Size(), implicitPad, f.Name)
+ }
+ case reflect.Array:
+ // Independent arrays are also always considered well aligned. We only
+ // need to worry about their alignment when they're embedded in structs,
+ // which we handle above.
+ default:
+ t.Fatalf("Unsupported type in ABI struct while checking for field alignment for type: %v", typ.Kind())
+ }
+ return true, uint64(typ.Size())
+}
diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl
new file mode 100644
index 000000000..c32eb559f
--- /dev/null
+++ b/tools/go_marshal/defs.bzl
@@ -0,0 +1,152 @@
+"""Marshal is a tool for generating marshalling interfaces for Go types.
+
+The recommended way is to use the go_library rule defined below with mostly
+identical configuration as the native go_library rule.
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "foo",
+ srcs = ["foo.go"],
+)
+
+Under the hood, the go_marshal rule is used to generate a file that will
+appear in a Go target; the output file should appear explicitly in a srcs list.
+For example (the above is still the preferred way):
+
+load("//tools/go_marshal:defs.bzl", "go_marshal")
+
+go_marshal(
+ name = "foo_abi",
+ srcs = ["foo.go"],
+ out = "foo_abi.go",
+ package = "foo",
+)
+
+go_library(
+ name = "foo",
+ srcs = [
+ "foo.go",
+ "foo_abi.go",
+ ],
+ deps = [
+ "//tools/go_marshal:marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ],
+)
+"""
+
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+
+def _go_marshal_impl(ctx):
+ """Execute the go_marshal tool."""
+ output = ctx.outputs.lib
+ output_test = ctx.outputs.test
+ (build_dir, _, _) = ctx.build_file_path.rpartition("/BUILD")
+
+ decl = "/".join(["gvisor.dev/gvisor", build_dir])
+
+ # Run the marshal command.
+ args = ["-output=%s" % output.path]
+ args += ["-pkg=%s" % ctx.attr.package]
+ args += ["-output_test=%s" % output_test.path]
+ args += ["-declarationPkg=%s" % decl]
+
+ if ctx.attr.debug:
+ args += ["-debug"]
+
+ args += ["--"]
+ for src in ctx.attr.srcs:
+ args += [f.path for f in src.files.to_list()]
+ ctx.actions.run(
+ inputs = ctx.files.srcs,
+ outputs = [output, output_test],
+ mnemonic = "GoMarshal",
+ progress_message = "go_marshal: %s" % ctx.label,
+ arguments = args,
+ executable = ctx.executable._tool,
+ )
+
+# Generates save and restore logic from a set of Go files.
+#
+# Args:
+# name: the name of the rule.
+# srcs: the input source files. These files should include all structs in the
+# package that need to be saved.
+# imports: an optional list of extra, non-aliased, Go-style absolute import
+# paths.
+# out: the name of the generated file output. This must not conflict with any
+# other files and must be added to the srcs of the relevant go_library.
+# package: the package name for the input sources.
+go_marshal = rule(
+ implementation = _go_marshal_impl,
+ attrs = {
+ "srcs": attr.label_list(mandatory = True, allow_files = True),
+ "libname": attr.string(mandatory = True),
+ "imports": attr.string_list(mandatory = False),
+ "package": attr.string(mandatory = True),
+ "debug": attr.bool(doc = "enable debugging output from the go_marshal tool"),
+ "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_marshal:go_marshal")),
+ },
+ outputs = {
+ "lib": "%{name}_unsafe.go",
+ "test": "%{name}_test.go",
+ },
+)
+
+def go_library(name, srcs, deps = [], imports = [], debug = False, **kwargs):
+ """wraps the standard go_library and does mashalling interface generation.
+
+ Args:
+ name: Same as native go_library.
+ srcs: Same as native go_library.
+ deps: Same as native go_library.
+ imports: Extra import paths to pass to the go_marshal tool.
+ debug: Enables debugging output from the go_marshal tool.
+ **kwargs: Remaining args to pass to the native go_library rule unmodified.
+ """
+ go_marshal(
+ name = name + "_abi_autogen",
+ libname = name,
+ srcs = [src for src in srcs if src.endswith(".go")],
+ debug = debug,
+ imports = imports,
+ package = name,
+ )
+
+ extra_deps = [
+ "//tools/go_marshal/marshal",
+ "//pkg/sentry/platform/safecopy",
+ "//pkg/sentry/usermem",
+ ]
+
+ all_srcs = srcs + [name + "_abi_autogen_unsafe.go"]
+ all_deps = deps + [] # + extra_deps
+
+ for extra in extra_deps:
+ if extra not in deps:
+ all_deps.append(extra)
+
+ _go_library(
+ name = name,
+ srcs = all_srcs,
+ deps = all_deps,
+ **kwargs
+ )
+
+ # Don't pass importpath arg to go_test.
+ kwargs.pop("importpath", "")
+
+ _go_test(
+ name = name + "_abi_autogen_test",
+ srcs = [name + "_abi_autogen_test.go"],
+ # Generated test has a fixed set of dependencies since we generate these
+ # tests. They should only depend on the library generated above, and the
+ # Marshallable interface.
+ deps = [
+ ":" + name,
+ "//tools/go_marshal/analysis",
+ ],
+ **kwargs
+ )
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
new file mode 100644
index 000000000..a0eae6492
--- /dev/null
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -0,0 +1,17 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gomarshal",
+ srcs = [
+ "generator.go",
+ "generator_interfaces.go",
+ "generator_tests.go",
+ "util.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/gomarshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
new file mode 100644
index 000000000..641ccd938
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -0,0 +1,382 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gomarshal implements the go_marshal code generator. See README.md.
+package gomarshal
+
+import (
+ "bytes"
+ "fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
+ "os"
+ "sort"
+)
+
+const (
+ marshalImport = "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ usermemImport = "gvisor.dev/gvisor/pkg/sentry/usermem"
+ safecopyImport = "gvisor.dev/gvisor/pkg/sentry/platform/safecopy"
+)
+
+// List of identifiers we use in generated code, that may conflict a
+// similarly-named source identifier. Avoid problems by refusing the generate
+// code when we see these.
+//
+// This only applies to import aliases at the moment. All other identifiers
+// are qualified by a receiver argument, since they're struct fields.
+//
+// All recievers are single letters, so we don't allow import aliases to be a
+// single letter.
+var badIdents = []string{
+ "src", "srcs", "dst", "dsts", "blk", "buf", "err",
+ // All single-letter identifiers.
+}
+
+// Generator drives code generation for a single invocation of the go_marshal
+// utility.
+//
+// The Generator holds arguments passed to the tool, and drives parsing,
+// processing and code Generator for all types marked with +marshal declared in
+// the input files.
+//
+// See Generator.run() as the entry point.
+type Generator struct {
+ // Paths to input go source files.
+ inputs []string
+ // Output file to write generated go source.
+ output *os.File
+ // Output file to write generated tests.
+ outputTest *os.File
+ // Package name for the generated file.
+ pkg string
+ // Go import path for package we're processing. This package should directly
+ // declare the type we're generating code for.
+ declaration string
+ // Set of extra packages to import in the generated file.
+ imports *importTable
+}
+
+// NewGenerator creates a new code Generator.
+func NewGenerator(srcs []string, out, outTest, pkg, declaration string, imports []string) (*Generator, error) {
+ f, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open output file %q: %v", out, err)
+ }
+ fTest, err := os.OpenFile(outTest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
+ if err != nil {
+ return nil, fmt.Errorf("Couldn't open test output file %q: %v", out, err)
+ }
+ g := Generator{
+ inputs: srcs,
+ output: f,
+ outputTest: fTest,
+ pkg: pkg,
+ declaration: declaration,
+ imports: newImportTable(),
+ }
+ for _, i := range imports {
+ // All imports on the extra imports list are unconditionally marked as
+ // used, so they're always added to the generated code.
+ g.imports.add(i).markUsed()
+ }
+ g.imports.add(marshalImport).markUsed()
+ // The follow imports may or may not be used by the generated
+ // code, depending what's required for the target types. Don't
+ // mark these imports as used by default.
+ g.imports.add(usermemImport)
+ g.imports.add(safecopyImport)
+ g.imports.add("unsafe")
+
+ return &g, nil
+}
+
+// writeHeader writes the header for the generated source file. The header
+// includes the package name, package level comments and import statements.
+func (g *Generator) writeHeader() error {
+ var b sourceBuffer
+ b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
+ b.emit("package %s\n\n", g.pkg)
+ if err := b.write(g.output); err != nil {
+ return err
+ }
+
+ return g.imports.write(g.output)
+}
+
+// writeTypeChecks writes a statement to force the compiler to perform a type
+// check for all Marshallable types referenced by the generated code.
+func (g *Generator) writeTypeChecks(ms map[string]struct{}) error {
+ if len(ms) == 0 {
+ return nil
+ }
+
+ msl := make([]string, 0, len(ms))
+ for m, _ := range ms {
+ msl = append(msl, m)
+ }
+ sort.Strings(msl)
+
+ var buf bytes.Buffer
+ fmt.Fprint(&buf, "// Marshallable types used by this file.\n")
+
+ for _, m := range msl {
+ fmt.Fprintf(&buf, "var _ marshal.Marshallable = (*%s)(nil)\n", m)
+ }
+ fmt.Fprint(&buf, "\n")
+
+ _, err := fmt.Fprint(g.output, buf.String())
+ return err
+}
+
+// parse processes all input files passed this generator and produces a set of
+// parsed go ASTs.
+func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
+ debugf("go_marshal invoked with %d input files:\n", len(g.inputs))
+ for _, path := range g.inputs {
+ debugf(" %s\n", path)
+ }
+
+ files := make([]*ast.File, 0, len(g.inputs))
+ fsets := make([]*token.FileSet, 0, len(g.inputs))
+
+ for _, path := range g.inputs {
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
+ if err != nil {
+ // Not a valid input file?
+ return nil, nil, fmt.Errorf("Input %q can't be parsed: %v", path, err)
+ }
+
+ if debugEnabled() {
+ debugf("AST for %q:\n", path)
+ ast.Print(fset, f)
+ }
+
+ files = append(files, f)
+ fsets = append(fsets, fset)
+ }
+
+ return files, fsets, nil
+}
+
+// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// declarations for which we need to generate the Marshallable interface.
+func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
+ var types []*ast.TypeSpec
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Type declaration?
+ if !ok || gdecl.Tok != token.TYPE {
+ debugfAt(f.Position(decl.Pos()), "Skipping declaration since it's not a type declaration.\n")
+ continue
+ }
+ // Does it have a comment?
+ if gdecl.Doc == nil {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment.\n")
+ continue
+ }
+ // Does the comment contain a "+marshal" line?
+ marked := false
+ for _, c := range gdecl.Doc.List {
+ if c.Text == "// +marshal" {
+ marked = true
+ break
+ }
+ }
+ if !marked {
+ debugfAt(f.Position(gdecl.Pos()), "Skipping declaration since it doesn't have a comment containing +marshal line.\n")
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ // We already confirmed we're in a type declaration earlier.
+ t := spec.(*ast.TypeSpec)
+ if _, ok := t.Type.(*ast.StructType); ok {
+ debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
+ }
+ debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ }
+ }
+ return types
+}
+
+// collectImports collects all imports from all input source files. Some of
+// these imports are copied to the generated output, if they're referenced by
+// the generated code.
+//
+// collectImports de-duplicates imports while building the list, and ensures
+// identifiers in the generated code don't conflict with any imported package
+// names.
+func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]importStmt {
+ badImportNames := make(map[string]bool)
+ for _, i := range badIdents {
+ badImportNames[i] = true
+ }
+
+ is := make(map[string]importStmt)
+ for _, decl := range a.Decls {
+ gdecl, ok := decl.(*ast.GenDecl)
+ // Import statement?
+ if !ok || gdecl.Tok != token.IMPORT {
+ continue
+ }
+ for _, spec := range gdecl.Specs {
+ i := g.imports.addFromSpec(spec.(*ast.ImportSpec), f)
+ debugf("Collected import '%s' as '%s'\n", i.path, i.name)
+
+ // Make sure we have an import that doesn't use any local names that
+ // would conflict with identifiers in the generated code.
+ if len(i.name) == 1 {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import has a single character local name '%s'; this may conflict with code generated by go_marshal, use a multi-character import alias", i.name))
+ }
+ if badImportNames[i.name] {
+ abortAt(f.Position(spec.Pos()), fmt.Sprintf("Import name '%s' is likely to conflict with code generated by go_marshal, use a different import alias", i.name))
+ }
+ }
+ }
+ return is
+
+}
+
+func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+ // We're guaranteed to have only struct type specs by now. See
+ // Generator.collectMarshallabeTypes.
+ i := newInterfaceGenerator(t, fset)
+ i.validate()
+ i.emitMarshallable()
+ return i
+}
+
+// generateOneTestSuite generates a test suite for the automatically generated
+// implementations type t.
+func (g *Generator) generateOneTestSuite(t *ast.TypeSpec) *testGenerator {
+ i := newTestGenerator(t, g.declaration)
+ i.emitTests()
+ return i
+}
+
+// Run is the entry point to code generation using g.
+//
+// Run parses all input source files specified in g and emits generated code.
+func (g *Generator) Run() error {
+ // Parse our input source files into ASTs and token sets.
+ asts, fsets, err := g.parse()
+ if err != nil {
+ return err
+ }
+
+ if len(asts) != len(fsets) {
+ panic("ASTs and FileSets don't match")
+ }
+
+ // Map of imports in source files; key = local package name, value = import
+ // path.
+ is := make(map[string]importStmt)
+ for i, a := range asts {
+ // Collect all imports from the source files. We may need to copy some
+ // of these to the generated code if they're referenced. This has to be
+ // done before the loop below because we need to process all ASTs before
+ // we start requesting imports to be copied one by one as we encounter
+ // them in each generated source.
+ for name, i := range g.collectImports(a, fsets[i]) {
+ is[name] = i
+ }
+ }
+
+ var impls []*interfaceGenerator
+ var ts []*testGenerator
+ // Set of Marshallable types referenced by generated code.
+ ms := make(map[string]struct{})
+ for i, a := range asts {
+ // Collect type declarations marked for code generation and generate
+ // Marshallable interfaces.
+ for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ impl := g.generateOne(t, fsets[i])
+ // Collect Marshallable types referenced by the generated code.
+ for ref, _ := range impl.ms {
+ ms[ref] = struct{}{}
+ }
+ impls = append(impls, impl)
+ // Collect imports referenced by the generated code and add them to
+ // the list of imports we need to copy to the generated code.
+ for name, _ := range impl.is {
+ if !g.imports.markUsed(name) {
+ panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'", impl.typeName(), name))
+ }
+ }
+ ts = append(ts, g.generateOneTestSuite(t))
+ }
+ }
+
+ // Tool was invoked with input files with no data structures marked for code
+ // generation. This is probably not what the user intended.
+ if len(impls) == 0 {
+ var buf bytes.Buffer
+ fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
+ for _, i := range g.inputs {
+ fmt.Fprintf(&buf, " %s\n", i)
+ }
+ abort(buf.String())
+ }
+
+ // Write output file header. These include things like package name and
+ // import statements.
+ if err := g.writeHeader(); err != nil {
+ return err
+ }
+
+ // Write type checks for referenced marshallable types to output file.
+ if err := g.writeTypeChecks(ms); err != nil {
+ return err
+ }
+
+ // Write generated interfaces to output file.
+ for _, i := range impls {
+ if err := i.write(g.output); err != nil {
+ return err
+ }
+ }
+
+ // Write generated tests to test file.
+ return g.writeTests(ts)
+}
+
+// writeTests outputs tests for the generated interface implementations to a go
+// source file.
+func (g *Generator) writeTests(ts []*testGenerator) error {
+ var b sourceBuffer
+ b.emit("package %s_test\n\n", g.pkg)
+ if err := b.write(g.outputTest); err != nil {
+ return err
+ }
+
+ imports := newImportTable()
+ for _, t := range ts {
+ imports.merge(t.imports)
+ }
+
+ if err := imports.write(g.outputTest); err != nil {
+ return err
+ }
+
+ for _, t := range ts {
+ if err := t.write(g.outputTest); err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
new file mode 100644
index 000000000..a712c14dc
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -0,0 +1,507 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "strings"
+)
+
+// interfaceGenerator generates marshalling interfaces for a single type.
+//
+// getState is not thread-safe.
+type interfaceGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // FileSet containing the tokens for the type we're processing.
+ f *token.FileSet
+
+ // is records external packages referenced by the generated implementation.
+ is map[string]struct{}
+
+ // ms records Marshallable types referenced by the generated implementation
+ // of t's interfaces.
+ ms map[string]struct{}
+
+ // as records embedded fields in t that are potentially not packed. The key
+ // is the accessor for the field.
+ as map[string]struct{}
+}
+
+// typeName returns the name of the type this g represents.
+func (g *interfaceGenerator) typeName() string {
+ return g.t.Name.Name
+}
+
+// newinterfaceGenerator creates a new interface generator.
+func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
+ if _, ok := t.Type.(*ast.StructType); !ok {
+ panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+ }
+ g := &interfaceGenerator{
+ t: t,
+ r: receiverName(t),
+ f: fset,
+ is: make(map[string]struct{}),
+ ms: make(map[string]struct{}),
+ as: make(map[string]struct{}),
+ }
+ g.recordUsedMarshallable(g.typeName())
+ return g
+}
+
+func (g *interfaceGenerator) recordUsedMarshallable(m string) {
+ g.ms[m] = struct{}{}
+
+}
+
+func (g *interfaceGenerator) recordUsedImport(i string) {
+ g.is[i] = struct{}{}
+
+}
+
+func (g *interfaceGenerator) recordPotentiallyNonPackedField(fieldName string) {
+ g.as[fieldName] = struct{}{}
+}
+
+func (g *interfaceGenerator) forEachField(fn func(f *ast.Field)) {
+ // This is guaranteed to succeed because g.t is always a struct.
+ st := g.t.Type.(*ast.StructType)
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
+func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string {
+ return fmt.Sprintf("%s.%s", g.r, n.Name)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with a
+// reference position to the input source. Same as abortAt, but uses g to
+// resolve p to position.
+func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
+ abortAt(g.f.Position(p), msg)
+}
+
+// validate ensures the type we're working with can be marshalled. These checks
+// are done ahead of time and in one place so we can make assumptions later.
+func (g *interfaceGenerator) validate() {
+ g.forEachField(func(f *ast.Field) {
+ if len(f.Names) == 0 {
+ g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
+ }
+ })
+
+ g.forEachField(func(f *ast.Field) {
+ fieldDispatcher{
+ primitive: func(_, t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we
+ // provide suggestions for some disallowed types and reject
+ // them, then attempt to marshal any remaining types by
+ // invoking the marshal.Marshallable interface on them. If
+ // these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code
+ // will fail with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+ },
+ selector: func(_, _, _ *ast.Ident) {
+ // No validation to perform on selector fields. However this
+ // callback must still be provided.
+ },
+ array: func(n, _ *ast.Ident, len int) {
+ a := f.Type.(*ast.ArrayType)
+ if a.Len == nil {
+ g.abortAt(f.Pos(), fmt.Sprintf("Dynamically sized slice '%s' cannot be marshalled, arrays must be statically sized", n.Name))
+ }
+
+ if _, ok := a.Len.(*ast.BasicLit); !ok {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Array size must be a literal, don's use consts or expressions"))
+ }
+
+ if _, ok := a.Elt.(*ast.Ident); !ok {
+ g.abortAt(a.Elt.Pos(), fmt.Sprintf("Marshalling not supported for arrays with %s elements, array elements must be primitive types", kindString(a.Elt)))
+ }
+
+ if len <= 0 {
+ g.abortAt(a.Len.Pos(), fmt.Sprintf("Marshalling not supported for zero length arrays, why does an ABI struct have one?"))
+ }
+ },
+ unhandled: func(_ *ast.Ident) {
+ g.abortAt(f.Pos(), fmt.Sprintf("Marshalling not supported for %s fields", kindString(f.Type)))
+ },
+ }.dispatch(f)
+ })
+}
+
+// scalarSize returns the size of type identified by t. If t isn't a primitive
+// type, the size isn't known at code generation time, and must be resolved via
+// the marshal.Marshallable interface.
+func (g *interfaceGenerator) scalarSize(t *ast.Ident) (size int, unknownSize bool) {
+ switch t.Name {
+ case "int8", "uint8", "byte":
+ return 1, false
+ case "int16", "uint16":
+ return 2, false
+ case "int32", "uint32":
+ return 4, false
+ case "int64", "uint64":
+ return 8, false
+ default:
+ return 0, true
+ }
+}
+
+func (g *interfaceGenerator) shift(bufVar string, n int) {
+ g.emit("%s = %s[%d:]\n", bufVar, bufVar, n)
+}
+
+func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
+ g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
+}
+
+func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 2)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 4)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ }
+}
+
+func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+ switch typ {
+ case "int8":
+ g.emit("%s = int8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "uint8":
+ g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+ case "byte":
+ g.emit("%s = %s[0]\n", accessor, bufVar)
+ g.shift(bufVar, 1)
+
+ case "int16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+ case "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.shift(bufVar, 2)
+
+ case "int32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+ case "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.shift(bufVar, 4)
+
+ case "int64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ case "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.shift(bufVar, 8)
+ default:
+ g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
+ g.shiftDynamic(bufVar, accessor)
+ g.recordPotentiallyNonPackedField(accessor)
+ }
+}
+
+// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
+// packed. Returns "", false if g.t has no fields that may be potentially
+// packed, otherwise returns <clause>, true, where <clause> is an expression
+// like "t.a.Packed() && t.b.Packed() && t.c.Packed()".
+func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
+ if len(g.as) == 0 {
+ return "", false
+ }
+
+ cs := make([]string, 0, len(g.as))
+ for accessor, _ := range g.as {
+ cs = append(cs, fmt.Sprintf("%s.Packed()", accessor))
+ }
+ return strings.Join(cs, " && "), true
+}
+
+func (g *interfaceGenerator) emitMarshallable() {
+ // Is g.t a packed struct without consideing field types?
+ thisPacked := true
+ g.forEachField(func(f *ast.Field) {
+ if f.Tag != nil {
+ if f.Tag.Value == "`marshal:\"unaligned\"`" {
+ if thisPacked {
+ debugfAt(g.f.Position(g.t.Pos()),
+ fmt.Sprintf("Marking type '%s' as not packed due to tag `marshal:\"unaligned\"`.\n", g.t.Name))
+ thisPacked = false
+ }
+ }
+ }
+ })
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ primitiveSize := 0
+ var dynamicSizeTerms []string
+
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("%s.SizeBytes()", g.fieldAccessor(n)))
+ }
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ tName := fmt.Sprintf("%s.%s", tX.Name, tSel.Name)
+ g.recordUsedImport(tX.Name)
+ g.recordUsedMarshallable(tName)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()", tName))
+ },
+ array: func(n, t *ast.Ident, len int) {
+ if len < 1 {
+ // Zero-length arrays should've been rejected by validate().
+ panic("unreachable")
+ }
+ if size, dynamic := g.scalarSize(t); !dynamic {
+ primitiveSize += size * len
+ } else {
+ g.recordUsedMarshallable(t.Name)
+ dynamicSizeTerms = append(dynamicSizeTerms, fmt.Sprintf("(*%s)(nil).SizeBytes()*%d", t.Name, len))
+ }
+ },
+ }.dispatch)
+ g.emit("return %d", primitiveSize)
+ if len(dynamicSizeTerms) > 0 {
+ g.incIndent()
+ }
+ {
+ for _, d := range dynamicSizeTerms {
+ g.emitNoIndent(" +\n")
+ g.emit(d)
+ }
+ }
+ if len(dynamicSizeTerms) > 0 {
+ g.decIndent()
+ }
+ })
+ g.emit("\n}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)] ~= %s(0)\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes():]\n", t.Name)
+ }
+ return
+ }
+ g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: dst[:sizeof(%s)*%d] ~= [%d]%s{0}\n", t.Name, size, size, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("dst", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("dst = dst[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.forEachField(fieldDispatcher{
+ primitive: func(n, t *ast.Ident) {
+ if n.Name == "_" {
+ g.emit("// Padding: var _ %s ~= src[:sizeof(%s)]\n", t.Name, t.Name)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can reference here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes():]\n", t.Name)
+ g.recordPotentiallyNonPackedField(fmt.Sprintf("(*%s)(nil)", t.Name))
+ }
+ return
+ }
+ g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ },
+ selector: func(n, tX, tSel *ast.Ident) {
+ g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ },
+ array: func(n, t *ast.Ident, size int) {
+ if n.Name == "_" {
+ g.emit("// Padding: ~ copy([%d]%s(%s), src[:sizeof(%s)*%d])\n", size, t.Name, g.fieldAccessor(n), t.Name, size)
+ if len, dynamic := g.scalarSize(t); !dynamic {
+ g.shift("src", len*size)
+ } else {
+ // We can't use shiftDynamic here because we don't have
+ // an instance of the dynamic type we can referece here
+ // (since the version in this struct is anonymous). Use
+ // a typed nil pointer to call SizeBytes() instead.
+ g.emit("src = src[(*%s)(nil).SizeBytes()*%d:]\n", t.Name, size)
+ }
+ return
+ }
+
+ g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.inIndent(func() {
+ g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ })
+ g.emit("}\n")
+ },
+ }.dispatch)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ expr, fieldsMaybePacked := g.areFieldsPackedExpression()
+ switch {
+ case !thisPacked:
+ g.emit("return false\n")
+ case fieldsMaybePacked:
+ g.emit("return %s\n", expr)
+ default:
+ g.emit("return true\n")
+
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName())
+ g.emit("%s.MarshalBytes(dst)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if thisPacked {
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ if cond, ok := g.areFieldsPackedExpression(); ok {
+ g.emit("if %s {\n", cond)
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("} else {\n")
+ g.inIndent(func() {
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ })
+ g.emit("}\n")
+ } else {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ }
+ } else {
+ g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
+ g.emit("%s.UnmarshalBytes(src)\n", g.r)
+ }
+ })
+ g.emit("}\n\n")
+
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
new file mode 100644
index 000000000..df25cb5b2
--- /dev/null
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -0,0 +1,154 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "fmt"
+ "go/ast"
+ "io"
+ "strings"
+)
+
+var standardImports = []string{
+ "fmt",
+ "reflect",
+ "testing",
+ "gvisor.dev/gvisor/tools/go_marshal/analysis",
+}
+
+type testGenerator struct {
+ sourceBuffer
+
+ // The type we're serializing.
+ t *ast.TypeSpec
+
+ // Receiver argument for generated methods.
+ r string
+
+ // Imports used by generated code.
+ imports *importTable
+
+ // Import statement for the package declaring the type we generated code
+ // for. We need this to construct test instances for the type, since the
+ // tests aren't written in the same package.
+ decl *importStmt
+}
+
+func newTestGenerator(t *ast.TypeSpec, declaration string) *testGenerator {
+ if _, ok := t.Type.(*ast.StructType); !ok {
+ panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
+ }
+ g := &testGenerator{
+ t: t,
+ r: receiverName(t),
+ imports: newImportTable(),
+ }
+
+ for _, i := range standardImports {
+ g.imports.add(i).markUsed()
+ }
+ g.decl = g.imports.add(declaration)
+ g.decl.markUsed()
+
+ return g
+}
+
+func (g *testGenerator) typeName() string {
+ return fmt.Sprintf("%s.%s", g.decl.name, g.t.Name.Name)
+}
+
+func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
+ // This is guaranteed to succeed because g.t is always a struct.
+ st := g.t.Type.(*ast.StructType)
+ for _, field := range st.Fields.List {
+ fn(field)
+ }
+}
+
+func (g *testGenerator) testFuncName(base string) string {
+ return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
+}
+
+func (g *testGenerator) inTestFunction(name string, body func()) {
+ g.emit("func %s(t *testing.T) {\n", g.testFuncName(name))
+ g.inIndent(body)
+ g.emit("}\n\n")
+}
+
+func (g *testGenerator) emitTestNonZeroSize() {
+ g.inTestFunction("TestSizeNonZero", func() {
+ g.emit("x := &%s{}\n", g.typeName())
+ g.emit("if x.SizeBytes() == 0 {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTestSuspectAlignment() {
+ g.inTestFunction("TestSuspectAlignment", func() {
+ g.emit("x := %s{}\n", g.typeName())
+ g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
+ })
+}
+
+func (g *testGenerator) emitTestMarshalUnmarshalPreservesData() {
+ g.inTestFunction("TestSafeMarshalUnmarshalPreservesData", func() {
+ g.emit("var x, y, z, yUnsafe, zUnsafe %s\n", g.typeName())
+ g.emit("analysis.RandomizeValue(&x)\n\n")
+
+ g.emit("buf := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalBytes(buf)\n")
+ g.emit("bufUnsafe := make([]byte, x.SizeBytes())\n")
+ g.emit("x.MarshalUnsafe(bufUnsafe)\n\n")
+
+ g.emit("y.UnmarshalBytes(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, y) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, y))\n")
+ })
+ g.emit("}\n")
+ g.emit("yUnsafe.UnmarshalBytes(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, yUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/Unmarshal cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, yUnsafe))\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("z.UnmarshalUnsafe(buf)\n")
+ g.emit("if !reflect.DeepEqual(x, z) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across Marshal/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, z))\n")
+ })
+ g.emit("}\n")
+ g.emit("zUnsafe.UnmarshalUnsafe(bufUnsafe)\n")
+ g.emit("if !reflect.DeepEqual(x, zUnsafe) {\n")
+ g.inIndent(func() {
+ g.emit("t.Fatal(fmt.Sprintf(\"Data corrupted across MarshalUnsafe/UnmarshalUnsafe cycle:\\nBefore: %%+v\\nAfter: %%+v\\n\", x, zUnsafe))\n")
+ })
+ g.emit("}\n")
+ })
+}
+
+func (g *testGenerator) emitTests() {
+ g.emitTestNonZeroSize()
+ g.emitTestSuspectAlignment()
+ g.emitTestMarshalUnmarshalPreservesData()
+}
+
+func (g *testGenerator) write(out io.Writer) error {
+ return g.sourceBuffer.write(out)
+}
diff --git a/tools/go_marshal/gomarshal/util.go b/tools/go_marshal/gomarshal/util.go
new file mode 100644
index 000000000..967537abf
--- /dev/null
+++ b/tools/go_marshal/gomarshal/util.go
@@ -0,0 +1,387 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package gomarshal
+
+import (
+ "bytes"
+ "flag"
+ "fmt"
+ "go/ast"
+ "go/token"
+ "io"
+ "os"
+ "path"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+var debug = flag.Bool("debug", false, "enables debugging output")
+
+// receiverName returns an appropriate receiver name given a type spec.
+func receiverName(t *ast.TypeSpec) string {
+ if len(t.Name.Name) < 1 {
+ // Zero length type name?
+ panic("unreachable")
+ }
+ return strings.ToLower(t.Name.Name[:1])
+}
+
+// kindString returns a user-friendly representation of an AST expr type.
+func kindString(e ast.Expr) string {
+ switch e.(type) {
+ case *ast.Ident:
+ return "scalar"
+ case *ast.ArrayType:
+ return "array"
+ case *ast.StructType:
+ return "struct"
+ case *ast.StarExpr:
+ return "pointer"
+ case *ast.FuncType:
+ return "function"
+ case *ast.InterfaceType:
+ return "interface"
+ case *ast.MapType:
+ return "map"
+ case *ast.ChanType:
+ return "channel"
+ default:
+ return reflect.TypeOf(e).String()
+ }
+}
+
+// fieldDispatcher is a collection of callbacks for handling different types of
+// fields in a struct declaration.
+type fieldDispatcher struct {
+ primitive func(n, t *ast.Ident)
+ selector func(n, tX, tSel *ast.Ident)
+ array func(n, t *ast.Ident, size int)
+ unhandled func(n *ast.Ident)
+}
+
+// Precondition: All dispatch callbacks that will be invoked must be
+// provided. Embedded fields are not allowed, len(f.Names) >= 1.
+func (fd fieldDispatcher) dispatch(f *ast.Field) {
+ // Each field declaration may actually be multiple declarations of the same
+ // type. For example, consider:
+ //
+ // type Point struct {
+ // x, y, z int
+ // }
+ //
+ // We invoke the call-backs once per such instance. Embedded fields are not
+ // allowed, and results in a panic.
+ if len(f.Names) < 1 {
+ panic("Precondition not met: attempted to dispatch on embedded field")
+ }
+
+ for _, name := range f.Names {
+ switch v := f.Type.(type) {
+ case *ast.Ident:
+ fd.primitive(name, v)
+ case *ast.SelectorExpr:
+ fd.selector(name, v.X.(*ast.Ident), v.Sel)
+ case *ast.ArrayType:
+ len := 0
+ if v.Len != nil {
+ // Non-literal array length is handled by generatorInterfaces.validate().
+ if lenLit, ok := v.Len.(*ast.BasicLit); ok {
+ var err error
+ len, err = strconv.Atoi(lenLit.Value)
+ if err != nil {
+ panic(err)
+ }
+ }
+ }
+ switch t := v.Elt.(type) {
+ case *ast.Ident:
+ fd.array(name, t, len)
+ default:
+ fd.array(name, nil, len)
+ }
+ default:
+ fd.unhandled(name)
+ }
+ }
+}
+
+// debugEnabled indicates whether debugging is enabled for gomarshal.
+func debugEnabled() bool {
+ return *debug
+}
+
+// abort aborts the go_marshal tool with the given error message.
+func abort(msg string) {
+ if !strings.HasSuffix(msg, "\n") {
+ msg += "\n"
+ }
+ fmt.Print(msg)
+ os.Exit(1)
+}
+
+// abortAt aborts the go_marshal tool with the given error message, with
+// a reference position to the input source.
+func abortAt(p token.Position, msg string) {
+ abort(fmt.Sprintf("%v:\n %s\n", p, msg))
+}
+
+// debugf conditionally prints a debug message.
+func debugf(f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf(f, a...)
+ }
+}
+
+// debugfAt conditionally prints a debug message with a reference to a position
+// in the input source.
+func debugfAt(p token.Position, f string, a ...interface{}) {
+ if debugEnabled() {
+ fmt.Printf("%s:\n %s", p, fmt.Sprintf(f, a...))
+ }
+}
+
+// emit generates a line of code in the output file.
+//
+// emit is a wrapper around writing a formatted string to the output
+// buffer. emit can be invoked in one of two ways:
+//
+// (1) emit("some string")
+// When emit is called with a single string argument, it is simply copied to
+// the output buffer without any further formatting.
+// (2) emit(fmtString, args...)
+// emit can also be invoked in a similar fashion to *Printf() functions,
+// where the first argument is a format string.
+//
+// Calling emit with a single argument that is not a string will result in a
+// panic, as the caller's intent is ambiguous.
+func emit(out io.Writer, indent int, a ...interface{}) {
+ const spacesPerIndentLevel = 4
+
+ if len(a) < 1 {
+ panic("emit() called with no arguments")
+ }
+
+ if indent > 0 {
+ if _, err := fmt.Fprint(out, strings.Repeat(" ", indent*spacesPerIndentLevel)); err != nil {
+ // Writing to the emit output should not fail. Typically the output
+ // is a byte.Buffer; writes to these never fail.
+ panic(err)
+ }
+ }
+
+ first, ok := a[0].(string)
+ if !ok {
+ // First argument must be either the string to emit (case 1 from
+ // function-level comment), or a format string (case 2).
+ panic(fmt.Sprintf("First argument to emit() is not a string: %+v", a[0]))
+ }
+
+ if len(a) == 1 {
+ // Single string argument. Assume no formatting requested.
+ if _, err := fmt.Fprint(out, first); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+ return
+
+ }
+
+ // Formatting requested.
+ if _, err := fmt.Fprintf(out, first, a[1:]...); err != nil {
+ // Writing to out should not fail.
+ panic(err)
+ }
+}
+
+// sourceBuffer represents fragments of generated go source code.
+//
+// sourceBuffer provides a convenient way to build up go souce fragments in
+// memory. May be safely zero-value initialized. Not thread-safe.
+type sourceBuffer struct {
+ // Current indentation level.
+ indent int
+
+ // Memory buffer containing contents while they're being generated.
+ b bytes.Buffer
+}
+
+func (b *sourceBuffer) incIndent() {
+ b.indent++
+}
+
+func (b *sourceBuffer) decIndent() {
+ if b.indent <= 0 {
+ panic("decIndent() without matching incIndent()")
+ }
+ b.indent--
+}
+
+func (b *sourceBuffer) emit(a ...interface{}) {
+ emit(&b.b, b.indent, a...)
+}
+
+func (b *sourceBuffer) emitNoIndent(a ...interface{}) {
+ emit(&b.b, 0 /*indent*/, a...)
+}
+
+func (b *sourceBuffer) inIndent(body func()) {
+ b.incIndent()
+ body()
+ b.decIndent()
+}
+
+func (b *sourceBuffer) write(out io.Writer) error {
+ _, err := fmt.Fprint(out, b.b.String())
+ return err
+}
+
+// Write implements io.Writer.Write.
+func (b *sourceBuffer) Write(buf []byte) (int, error) {
+ return (b.b.Write(buf))
+}
+
+// importStmt represents a single import statement.
+type importStmt struct {
+ // Local name of the imported package.
+ name string
+ // Import path.
+ path string
+ // Indicates whether the local name is an alias, or simply the final
+ // component of the path.
+ aliased bool
+ // Indicates whether this import was referenced by generated code.
+ used bool
+}
+
+func newImport(p string) *importStmt {
+ name := path.Base(p)
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: false,
+ }
+}
+
+func newImportFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ p := spec.Path.Value[1 : len(spec.Path.Value)-1] // Strip the " quotes around path.
+ name := path.Base(p)
+ if name == "" || name == "/" || name == "." {
+ panic(fmt.Sprintf("Couldn't process local package name for import at %s, (processed as %s)",
+ f.Position(spec.Path.Pos()), name))
+ }
+ if spec.Name != nil {
+ name = spec.Name.Name
+ }
+ return &importStmt{
+ name: name,
+ path: p,
+ aliased: spec.Name != nil,
+ }
+}
+
+func (i *importStmt) String() string {
+ if i.aliased {
+ return fmt.Sprintf("%s \"%s\"", i.name, i.path)
+ }
+ return fmt.Sprintf("\"%s\"", i.path)
+}
+
+func (i *importStmt) markUsed() {
+ i.used = true
+}
+
+func (i *importStmt) equivalent(other *importStmt) bool {
+ return i == other
+}
+
+// importTable represents a collection of importStmts.
+type importTable struct {
+ // Map of imports and whether they should be copied to the output.
+ is map[string]*importStmt
+}
+
+func newImportTable() *importTable {
+ return &importTable{
+ is: make(map[string]*importStmt),
+ }
+}
+
+// Merges import statements from other into i. Collisions in import statements
+// result in a panic.
+func (i *importTable) merge(other *importTable) {
+ for name, im := range other.is {
+ if dup, ok := i.is[name]; ok && dup.equivalent(im) {
+ panic(fmt.Sprintf("Found colliding import statements: ours: %+v, other's: %+v", dup, im))
+ }
+
+ i.is[name] = im
+ }
+}
+
+func (i *importTable) add(s string) *importStmt {
+ n := newImport(s)
+ i.is[n.name] = n
+ return n
+}
+
+func (i *importTable) addFromSpec(spec *ast.ImportSpec, f *token.FileSet) *importStmt {
+ n := newImportFromSpec(spec, f)
+ i.is[n.name] = n
+ return n
+}
+
+// Marks the import named n as used. If no such import is in the table, returns
+// false.
+func (i *importTable) markUsed(n string) bool {
+ if n, ok := i.is[n]; ok {
+ n.markUsed()
+ return true
+ }
+ return false
+}
+
+func (i *importTable) clear() {
+ for _, i := range i.is {
+ i.used = false
+ }
+}
+
+func (i *importTable) write(out io.Writer) error {
+ if len(i.is) == 0 {
+ // Nothing to import, we're done.
+ return nil
+ }
+
+ imports := make([]string, 0, len(i.is))
+ for _, i := range i.is {
+ if i.used {
+ imports = append(imports, i.String())
+ }
+ }
+ sort.Strings(imports)
+
+ var b sourceBuffer
+ b.emit("import (\n")
+ b.incIndent()
+ for _, i := range imports {
+ b.emit("%s\n", i)
+ }
+ b.decIndent()
+ b.emit(")\n\n")
+
+ return b.write(out)
+}
diff --git a/tools/go_marshal/main.go b/tools/go_marshal/main.go
new file mode 100644
index 000000000..3d12eb93c
--- /dev/null
+++ b/tools/go_marshal/main.go
@@ -0,0 +1,73 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// go_marshal is a code generation utility for automatically generating code to
+// marshal go data structures to memory.
+//
+// This binary is typically run as part of the build process, and is invoked by
+// the go_marshal bazel rule defined in defs.bzl.
+//
+// See README.md.
+package main
+
+import (
+ "flag"
+ "fmt"
+ "os"
+ "strings"
+
+ "gvisor.dev/gvisor/tools/go_marshal/gomarshal"
+)
+
+var (
+ pkg = flag.String("pkg", "", "output package")
+ output = flag.String("output", "", "output file")
+ outputTest = flag.String("output_test", "", "output file for tests")
+ imports = flag.String("imports", "", "comma-separated list of extra packages to import in generated code")
+ declarationPkg = flag.String("declarationPkg", "", "import path of target declaring the types we're generating on")
+)
+
+func main() {
+ flag.Usage = func() {
+ fmt.Fprintf(os.Stderr, "Usage: %s <input go src files>\n", os.Args[0])
+ flag.PrintDefaults()
+ }
+ flag.Parse()
+ if len(flag.Args()) == 0 {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ if *pkg == "" {
+ flag.Usage()
+ fmt.Fprint(os.Stderr, "Flag -pkg must be provided.\n")
+ os.Exit(1)
+ }
+
+ var extraImports []string
+ if len(*imports) > 0 {
+ // Note: strings.Split(s, sep) returns s if sep doesn't exist in s. Thus
+ // we check for an empty imports list to avoid emitting an empty string
+ // as an import.
+ extraImports = strings.Split(*imports, ",")
+ }
+ g, err := gomarshal.NewGenerator(flag.Args(), *output, *outputTest, *pkg, *declarationPkg, extraImports)
+ if err != nil {
+ panic(err)
+ }
+
+ if err := g.Run(); err != nil {
+ panic(err)
+ }
+}
diff --git a/tools/go_marshal/marshal/BUILD b/tools/go_marshal/marshal/BUILD
new file mode 100644
index 000000000..47dda97a1
--- /dev/null
+++ b/tools/go_marshal/marshal/BUILD
@@ -0,0 +1,14 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "marshal",
+ srcs = [
+ "marshal.go",
+ ],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/marshal",
+ visibility = [
+ "//:sandbox",
+ ],
+)
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
new file mode 100644
index 000000000..a313a27ed
--- /dev/null
+++ b/tools/go_marshal/marshal/marshal.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package marshal defines the Marshallable interface for
+// serialize/deserializing go data structures to/from memory, according to the
+// Linux ABI.
+//
+// Implementations of this interface are typically automatically generated by
+// tools/go_marshal. See the go_marshal README for details.
+package marshal
+
+// Marshallable represents a type that can be marshalled to and from memory.
+type Marshallable interface {
+ // SizeBytes is the size of the memory representation of a type in
+ // marshalled form.
+ SizeBytes() int
+
+ // MarshalBytes serializes a copy of a type to dst. dst must be at least
+ // SizeBytes() long.
+ MarshalBytes(dst []byte)
+
+ // UnmarshalBytes deserializes a type from src. src must be at least
+ // SizeBytes() long.
+ UnmarshalBytes(src []byte)
+
+ // Packed returns true if the marshalled size of the type is the same as the
+ // size it occupies in memory. This happens when the type has no fields
+ // starting at unaligned addresses (should always be true by default for ABI
+ // structs, verified by automatically generated tests when using
+ // go_marshal), and has no fields marked `marshal:"unaligned"`.
+ Packed() bool
+
+ // MarshalUnsafe serializes a type by bulk copying its in-memory
+ // representation to the dst buffer. This is only safe to do when the type
+ // has no implicit padding, see Marshallable.Packed. When Packed would
+ // return false, MarshalUnsafe should fall back to the safer but slower
+ // MarshalBytes.
+ MarshalUnsafe(dst []byte)
+
+ // UnmarshalUnsafe deserializes a type directly to the underlying memory
+ // allocated for the object by the runtime.
+ //
+ // This allows much faster unmarshalling of types which have no implicit
+ // padding, see Marshallable.Packed. When Packed would return false,
+ // UnmarshalUnsafe should fall back to the safer but slower unmarshal
+ // mechanism implemented in UnmarshalBytes (usually by calling
+ // UnmarshalBytes directly).
+ UnmarshalUnsafe(src []byte)
+}
diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD
new file mode 100644
index 000000000..fa82f8e9b
--- /dev/null
+++ b/tools/go_marshal/test/BUILD
@@ -0,0 +1,31 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+package(licenses = ["notice"])
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+package_group(
+ name = "gomarshal_test",
+ packages = [
+ "//tools/go_marshal/test/...",
+ ],
+)
+
+go_test(
+ name = "benchmark_test",
+ srcs = ["benchmark_test.go"],
+ deps = [
+ ":test",
+ "//pkg/binary",
+ "//pkg/sentry/usermem",
+ "//tools/go_marshal/analysis",
+ ],
+)
+
+go_library(
+ name = "test",
+ testonly = 1,
+ srcs = ["test.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/test",
+ deps = ["//tools/go_marshal/test/external"],
+)
diff --git a/tools/go_marshal/test/benchmark_test.go b/tools/go_marshal/test/benchmark_test.go
new file mode 100644
index 000000000..e70db06d8
--- /dev/null
+++ b/tools/go_marshal/test/benchmark_test.go
@@ -0,0 +1,178 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package benchmark_test
+
+import (
+ "bytes"
+ encbin "encoding/binary"
+ "fmt"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/tools/go_marshal/analysis"
+ test "gvisor.dev/gvisor/tools/go_marshal/test"
+)
+
+// Marshalling using the standard encoding/binary package.
+func BenchmarkEncodingBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := encbin.Size(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := bytes.NewBuffer(make([]byte, size))
+ buf.Reset()
+ if err := encbin.Write(buf, usermem.ByteOrder, &s1); err != nil {
+ b.Error("Write:", err)
+ }
+ if err := encbin.Read(buf, usermem.ByteOrder, &s2); err != nil {
+ b.Error("Read:", err)
+ }
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling using the sentry's binary.Marshal.
+func BenchmarkBinary(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ size := binary.Size(s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &s1)
+ binary.Unmarshal(buf, usermem.ByteOrder, &s2)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling field-by-field with manually-written code.
+func BenchmarkMarshalManual(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, 0, s1.SizeBytes())
+
+ // Marshal
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Dev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Ino)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Nlink)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.Mode)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.UID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, s1.GID)
+ buf = binary.AppendUint32(buf, usermem.ByteOrder, 0)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, s1.Rdev)
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Size))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blksize))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.Blocks))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.ATime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.MTime.Nsec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Sec))
+ buf = binary.AppendUint64(buf, usermem.ByteOrder, uint64(s1.CTime.Nsec))
+
+ // Unmarshal
+ s2.Dev = usermem.ByteOrder.Uint64(buf[0:8])
+ s2.Ino = usermem.ByteOrder.Uint64(buf[8:16])
+ s2.Nlink = usermem.ByteOrder.Uint64(buf[16:24])
+ s2.Mode = usermem.ByteOrder.Uint32(buf[24:28])
+ s2.UID = usermem.ByteOrder.Uint32(buf[28:32])
+ s2.GID = usermem.ByteOrder.Uint32(buf[32:36])
+ // Padding: buf[36:40]
+ s2.Rdev = usermem.ByteOrder.Uint64(buf[40:48])
+ s2.Size = int64(usermem.ByteOrder.Uint64(buf[48:56]))
+ s2.Blksize = int64(usermem.ByteOrder.Uint64(buf[56:64]))
+ s2.Blocks = int64(usermem.ByteOrder.Uint64(buf[64:72]))
+ s2.ATime.Sec = int64(usermem.ByteOrder.Uint64(buf[72:80]))
+ s2.ATime.Nsec = int64(usermem.ByteOrder.Uint64(buf[80:88]))
+ s2.MTime.Sec = int64(usermem.ByteOrder.Uint64(buf[88:96]))
+ s2.MTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[96:104]))
+ s2.CTime.Sec = int64(usermem.ByteOrder.Uint64(buf[104:112]))
+ s2.CTime.Nsec = int64(usermem.ByteOrder.Uint64(buf[112:120]))
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal safe API.
+func BenchmarkGoMarshalSafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalBytes(buf)
+ s2.UnmarshalBytes(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
+
+// Marshalling with the go_marshal unsafe API.
+func BenchmarkGoMarshalUnsafe(b *testing.B) {
+ var s1, s2 test.Stat
+ analysis.RandomizeValue(&s1)
+
+ b.ResetTimer()
+
+ for n := 0; n < b.N; n++ {
+ buf := make([]byte, s1.SizeBytes())
+ s1.MarshalUnsafe(buf)
+ s2.UnmarshalUnsafe(buf)
+ }
+
+ b.StopTimer()
+
+ // Sanity check, make sure the values were preserved.
+ if !reflect.DeepEqual(s1, s2) {
+ panic(fmt.Sprintf("Data corruption across marshal/unmarshal cycle:\nBefore: %+v\nAfter: %+v\n", s1, s2))
+ }
+}
diff --git a/tools/go_marshal/test/external/BUILD b/tools/go_marshal/test/external/BUILD
new file mode 100644
index 000000000..8fb43179b
--- /dev/null
+++ b/tools/go_marshal/test/external/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"])
+
+load("//tools/go_marshal:defs.bzl", "go_library")
+
+go_library(
+ name = "external",
+ testonly = 1,
+ srcs = ["external.go"],
+ importpath = "gvisor.dev/gvisor/tools/go_marshal/test/external",
+ visibility = ["//tools/go_marshal/test:gomarshal_test"],
+)
diff --git a/test/runtimes/runtimes.go b/tools/go_marshal/test/external/external.go
index 2568e07fe..4be3722f3 100644
--- a/test/runtimes/runtimes.go
+++ b/tools/go_marshal/test/external/external.go
@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package runtimes provides language tests for runsc runtimes.
-// Each test calls docker commands to start up a container for each supported runtime,
-// and tests that its respective language tests are behaving as expected, like
-// connecting to a port or looking at the output. The container is killed and deleted
-// at the end.
-package runtimes
+// Package external defines types we can import for testing.
+package external
+
+// External is a public Marshallable type for use in testing.
+//
+// +marshal
+type External struct {
+ j int64
+}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
new file mode 100644
index 000000000..8de02d707
--- /dev/null
+++ b/tools/go_marshal/test/test.go
@@ -0,0 +1,105 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test contains data structures for testing the go_marshal tool.
+package test
+
+import (
+ // We're intentionally using a package name alias here even though it's not
+ // necessary to test the code generator's ability to handle package aliases.
+ ex "gvisor.dev/gvisor/tools/go_marshal/test/external"
+)
+
+// Type1 is a test data type.
+//
+// +marshal
+type Type1 struct {
+ a Type2
+ x, y int64 // Multiple field names.
+ b byte `marshal:"unaligned"` // Short field.
+ c uint64
+ _ uint32 // Unnamed scalar field.
+ _ [6]byte // Unnamed vector field, typical padding.
+ _ [2]byte
+ xs [8]int32
+ as [10]Type2 `marshal:"unaligned"` // Array of Marshallable objects.
+ ss Type3
+}
+
+// Type2 is a test data type.
+//
+// +marshal
+type Type2 struct {
+ n int64
+ c byte
+ _ [7]byte
+ m int64
+ a int64
+}
+
+// Type3 is a test data type.
+//
+// +marshal
+type Type3 struct {
+ s int64
+ x ex.External // Type defined in another package.
+}
+
+// Type4 is a test data type.
+//
+// +marshal
+type Type4 struct {
+ c byte
+ x int64 `marshal:"unaligned"`
+ d byte
+ _ [7]byte
+}
+
+// Type5 is a test data type.
+//
+// +marshal
+type Type5 struct {
+ n int64
+ t Type4
+ m int64
+}
+
+// Timespec represents struct timespec in <time.h>.
+//
+// +marshal
+type Timespec struct {
+ Sec int64
+ Nsec int64
+}
+
+// Stat represents struct stat.
+//
+// +marshal
+type Stat struct {
+ Dev uint64
+ Ino uint64
+ Nlink uint64
+ Mode uint32
+ UID uint32
+ GID uint32
+ _ int32
+ Rdev uint64
+ Size int64
+ Blksize int64
+ Blocks int64
+ ATime Timespec
+ MTime Timespec
+ CTime Timespec
+ _ [3]int64
+}
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
index aeba197e2..3ce36c1c8 100644
--- a/tools/go_stateify/defs.bzl
+++ b/tools/go_stateify/defs.bzl
@@ -35,7 +35,7 @@ go_library(
)
"""
-load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library", _go_test = "go_test")
+load("@io_bazel_rules_go//go:def.bzl", _go_library = "go_library")
def _go_stateify_impl(ctx):
"""Implementation for the stateify tool."""
@@ -60,28 +60,57 @@ def _go_stateify_impl(ctx):
executable = ctx.executable._tool,
)
-# Generates save and restore logic from a set of Go files.
-#
-# Args:
-# name: the name of the rule.
-# srcs: the input source files. These files should include all structs in the package that need to be saved.
-# imports: an optional list of extra non-aliased, Go-style absolute import paths.
-# out: the name of the generated file output. This must not conflict with any other files and must be added to the srcs of the relevant go_library.
-# package: the package name for the input sources.
go_stateify = rule(
implementation = _go_stateify_impl,
+ doc = "Generates save and restore logic from a set of Go files.",
attrs = {
- "srcs": attr.label_list(mandatory = True, allow_files = True),
- "imports": attr.string_list(mandatory = False),
- "package": attr.string(mandatory = True),
- "out": attr.output(mandatory = True),
- "_tool": attr.label(executable = True, cfg = "host", default = Label("//tools/go_stateify:stateify")),
+ "srcs": attr.label_list(
+ doc = """
+The input source files. These files should include all structs in the package
+that need to be saved.
+""",
+ mandatory = True,
+ allow_files = True,
+ ),
+ "imports": attr.string_list(
+ doc = """
+An optional list of extra non-aliased, Go-style absolute import paths required
+for statified types.
+""",
+ mandatory = False,
+ ),
+ "package": attr.string(
+ doc = "The package name for the input sources.",
+ mandatory = True,
+ ),
+ "out": attr.output(
+ doc = """
+The name of the generated file output. This must not conflict with any other
+files and must be added to the srcs of the relevant go_library.
+""",
+ mandatory = True,
+ ),
+ "_tool": attr.label(
+ executable = True,
+ cfg = "host",
+ default = Label("//tools/go_stateify:stateify"),
+ ),
"_statepkg": attr.string(default = "gvisor.dev/gvisor/pkg/state"),
},
)
def go_library(name, srcs, deps = [], imports = [], **kwargs):
- """wraps the standard go_library and does stateification."""
+ """Standard go_library wrapped which generates state source files.
+
+ Args:
+ name: the name of the go_library rule.
+ srcs: sources of the go_library. Each will be processed for stateify
+ annotations.
+ deps: dependencies for the go_library.
+ imports: an optional list of extra non-aliased, Go-style absolute import
+ paths required for stateified types.
+ **kwargs: passed to go_library.
+ """
if "encode_unsafe.go" not in srcs and (name + "_state_autogen.go") not in srcs:
# Only do stateification for non-state packages without manual autogen.
go_stateify(
@@ -105,9 +134,3 @@ def go_library(name, srcs, deps = [], imports = [], **kwargs):
deps = all_deps,
**kwargs
)
-
-def go_test(**kwargs):
- """Wraps the standard go_test."""
- _go_test(
- **kwargs
- )
diff --git a/tools/make_repository.sh b/tools/make_repository.sh
index bf9c50d74..071f72b74 100755
--- a/tools/make_repository.sh
+++ b/tools/make_repository.sh
@@ -16,13 +16,14 @@
# Parse arguments. We require more than two arguments, which are the private
# keyring, the e-mail associated with the signer, and the list of packages.
-if [ "$#" -le 2 ]; then
- echo "usage: $0 <private-key> <signer-email> <packages...>"
+if [ "$#" -le 3 ]; then
+ echo "usage: $0 <private-key> <signer-email> <component> <packages...>"
exit 1
fi
declare -r private_key=$(readlink -e "$1")
declare -r signer="$2"
-shift; shift
+declare -r component="$3"
+shift; shift; shift
# Verbose from this point.
set -xeo pipefail
@@ -37,13 +38,20 @@ cleanup() {
rm -f "${keyring}"
}
trap cleanup EXIT
-gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}"
-
-# Export the public key from the keyring.
-gpg --no-default-keyring --keyring "${keyring}" --armor --export "${signer}" > "${tmpdir}"/keyFile
+gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" >&2
# Copy the packages, and ensure permissions are correct.
-cp -a "$@" "${tmpdir}" && chmod 0644 "${tmpdir}"/*
+for pkg in "$@"; do
+ name=$(basename "${pkg}" .deb)
+ name=$(basename "${name}" .changes)
+ arch=${name##*_}
+ if [[ "${name}" == "${arch}" ]]; then
+ continue # Not a regular package.
+ fi
+ mkdir -p "${tmpdir}"/"${component}"/binary-"${arch}"
+ cp -a "${pkg}" "${tmpdir}"/"${component}"/binary-"${arch}"
+done
+find "${tmpdir}" -type f -exec chmod 0644 {} \;
# Ensure there are no symlinks hanging around; these may be remnants of the
# build process. They may be useful for other things, but we are going to build
@@ -51,19 +59,21 @@ cp -a "$@" "${tmpdir}" && chmod 0644 "${tmpdir}"/*
find "${tmpdir}" -type l -exec rm -f {} \;
# Sign all packages.
-for file in "${tmpdir}"/*.deb; do
- dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}"
+for file in "${tmpdir}"/"${component}"/binary-*/*.deb; do
+ dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${file}" >&2
done
# Build the package list.
-(cd "${tmpdir}" && apt-ftparchive packages . | gzip > Packages.gz)
+for dir in "${tmpdir}"/"${component}"/binary-*; do
+ (cd "${dir}" && apt-ftparchive packages . | gzip > Packages.gz)
+done
# Build the release list.
(cd "${tmpdir}" && apt-ftparchive release . > Release)
# Sign the release.
-(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release)
-(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release)
+(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign -o InRelease Release >&2)
+(cd "${tmpdir}" && gpg --no-default-keyring --keyring "${keyring}" -abs -o Release.gpg Release >&2)
# Show the results.
echo "${tmpdir}"
diff --git a/tools/workspace_status.sh b/tools/workspace_status.sh
index 64a905fc9..fb09ff331 100755
--- a/tools/workspace_status.sh
+++ b/tools/workspace_status.sh
@@ -14,4 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-echo VERSION $(git describe --always --tags --abbrev=12 --dirty)
+# The STABLE_ prefix will trigger a re-link if it changes.
+echo STABLE_VERSION $(git describe --always --tags --abbrev=12 --dirty)