summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.travis.yml19
-rw-r--r--Dockerfile11
-rw-r--r--Makefile5
-rw-r--r--WORKSPACE31
-rw-r--r--benchmarks/harness/machine.py9
-rw-r--r--benchmarks/harness/ssh_connection.py9
-rw-r--r--kokoro/benchmark_tests.cfg26
-rw-r--r--kokoro/runtime_tests/go1.12.cfg10
-rw-r--r--kokoro/runtime_tests/java11.cfg10
-rw-r--r--kokoro/runtime_tests/nodejs12.4.0.cfg10
-rw-r--r--kokoro/runtime_tests/php7.3.6.cfg10
-rw-r--r--kokoro/runtime_tests/python3.7.3.cfg10
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/epoll.go7
-rw-r--r--pkg/abi/linux/epoll_amd64.go2
-rw-r--r--pkg/abi/linux/epoll_arm64.go2
-rw-r--r--pkg/abi/linux/file.go2
-rw-r--r--pkg/abi/linux/fs.go2
-rw-r--r--pkg/abi/linux/ioctl.go26
-rw-r--r--pkg/abi/linux/ioctl_tun.go (renamed from pkg/usermem/usermem_unsafe.go)24
-rw-r--r--pkg/abi/linux/netfilter.go9
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/time.go6
-rw-r--r--pkg/abi/linux/xattr.go1
-rw-r--r--pkg/atomicbitops/BUILD10
-rw-r--r--pkg/atomicbitops/atomicbitops.go (renamed from pkg/atomicbitops/atomic_bitops.go)31
-rw-r--r--pkg/atomicbitops/atomicbitops_amd64.s (renamed from pkg/atomicbitops/atomic_bitops_amd64.s)38
-rw-r--r--pkg/atomicbitops/atomicbitops_arm64.s (renamed from pkg/atomicbitops/atomic_bitops_arm64.s)34
-rw-r--r--pkg/atomicbitops/atomicbitops_noasm.go (renamed from pkg/atomicbitops/atomic_bitops_common.go)42
-rw-r--r--pkg/atomicbitops/atomicbitops_test.go (renamed from pkg/atomicbitops/atomic_bitops_test.go)64
-rw-r--r--pkg/cpuid/cpuid_x86.go19
-rw-r--r--pkg/fspath/BUILD4
-rw-r--r--pkg/fspath/builder.go8
-rw-r--r--pkg/fspath/fspath.go3
-rw-r--r--pkg/gohacks/BUILD11
-rw-r--r--pkg/gohacks/gohacks_unsafe.go57
-rw-r--r--pkg/log/glog.go6
-rw-r--r--pkg/log/json.go2
-rw-r--r--pkg/log/json_k8s.go2
-rw-r--r--pkg/log/log.go60
-rw-r--r--pkg/sentry/arch/arch_aarch64.go3
-rw-r--r--pkg/sentry/fs/dev/BUILD5
-rw-r--r--pkg/sentry/fs/dev/dev.go10
-rw-r--r--pkg/sentry/fs/dev/net_tun.go170
-rw-r--r--pkg/sentry/fs/mount_test.go11
-rw-r--r--pkg/sentry/fs/mounts.go10
-rw-r--r--pkg/sentry/fs/proc/mounts.go16
-rw-r--r--pkg/sentry/fs/proc/net.go5
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fsbridge/vfs.go10
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_net.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go4
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go233
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go16
-rw-r--r--pkg/sentry/inet/BUILD1
-rw-r--r--pkg/sentry/inet/namespace.go99
-rw-r--r--pkg/sentry/kernel/fd_table.go49
-rw-r--r--pkg/sentry/kernel/fs_context.go22
-rw-r--r--pkg/sentry/kernel/kernel.go30
-rw-r--r--pkg/sentry/kernel/rseq.go2
-rw-r--r--pkg/sentry/kernel/task.go27
-rw-r--r--pkg/sentry/kernel/task_clone.go16
-rw-r--r--pkg/sentry/kernel/task_context.go2
-rw-r--r--pkg/sentry/kernel/task_exec.go2
-rw-r--r--pkg/sentry/kernel/task_log.go6
-rw-r--r--pkg/sentry/kernel/task_net.go19
-rw-r--r--pkg/sentry/kernel/task_start.go8
-rw-r--r--pkg/sentry/kernel/task_usermem.go2
-rw-r--r--pkg/sentry/mm/address_space.go46
-rw-r--r--pkg/sentry/mm/lifecycle.go37
-rw-r--r--pkg/sentry/mm/mm.go5
-rw-r--r--pkg/sentry/mm/mm_test.go2
-rw-r--r--pkg/sentry/platform/kvm/machine.go18
-rw-r--r--pkg/sentry/socket/control/control.go2
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go63
-rw-r--r--pkg/sentry/socket/netfilter/targets.go35
-rw-r--r--pkg/sentry/socket/netstack/netstack.go27
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/epoll.go89
-rw-r--r--pkg/sentry/strace/linux64_amd64.go6
-rw-r--r--pkg/sentry/strace/linux64_arm64.go4
-rw-r--r--pkg/sentry/strace/strace.go8
-rw-r--r--pkg/sentry/strace/syscalls.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go40
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD28
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go225
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go44
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go147
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go326
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fscontext.go131
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go149
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go35
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go216
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go92
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go94
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go584
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go511
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go380
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go346
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go87
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sys_read.go95
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go353
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/epoll.go3
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go12
-rw-r--r--pkg/sentry/vfs/resolving_path.go2
-rw-r--r--pkg/sentry/vfs/vfs.go10
-rw-r--r--pkg/syncevent/BUILD39
-rw-r--r--pkg/syncevent/broadcaster.go218
-rw-r--r--pkg/syncevent/broadcaster_test.go376
-rw-r--r--pkg/syncevent/receiver.go103
-rw-r--r--pkg/syncevent/source.go59
-rw-r--r--pkg/syncevent/syncevent.go32
-rw-r--r--pkg/syncevent/syncevent_example_test.go108
-rw-r--r--pkg/syncevent/waiter_amd64.s32
-rw-r--r--pkg/syncevent/waiter_arm64.s34
-rw-r--r--pkg/syncevent/waiter_asm_unsafe.go (renamed from pkg/fspath/builder_unsafe.go)15
-rw-r--r--pkg/syncevent/waiter_noasm_unsafe.go39
-rw-r--r--pkg/syncevent/waiter_test.go414
-rw-r--r--pkg/syncevent/waiter_unsafe.go206
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go63
-rw-r--r--pkg/tcpip/buffer/view.go6
-rw-r--r--pkg/tcpip/checker/checker.go14
-rw-r--r--pkg/tcpip/iptables/iptables.go104
-rw-r--r--pkg/tcpip/iptables/targets.go28
-rw-r--r--pkg/tcpip/iptables/types.go19
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go180
-rw-r--r--pkg/tcpip/link/tun/BUILD18
-rw-r--r--pkg/tcpip/link/tun/device.go352
-rw-r--r--pkg/tcpip/link/tun/protocol.go56
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go6
-rw-r--r--pkg/tcpip/stack/ndp.go25
-rw-r--r--pkg/tcpip/stack/ndp_test.go900
-rw-r--r--pkg/tcpip/stack/nic.go149
-rw-r--r--pkg/tcpip/stack/registration.go25
-rw-r--r--pkg/tcpip/stack/stack.go103
-rw-r--r--pkg/tcpip/stack/stack_test.go439
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go15
-rw-r--r--pkg/tcpip/tcpip.go23
-rw-r--r--pkg/tcpip/time_unsafe.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go16
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go5
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go31
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/udp/protocol.go14
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go120
-rw-r--r--pkg/usermem/BUILD2
-rw-r--r--pkg/usermem/usermem.go9
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/controller.go11
-rw-r--r--runsc/boot/filter/config.go2
-rw-r--r--runsc/boot/loader.go121
-rw-r--r--runsc/boot/network.go27
-rw-r--r--runsc/boot/pprof/BUILD11
-rw-r--r--runsc/boot/pprof/pprof.go (renamed from runsc/boot/pprof.go)6
-rw-r--r--runsc/cmd/BUILD3
-rw-r--r--runsc/cmd/statefile.go143
-rw-r--r--runsc/container/container_test.go1
-rw-r--r--runsc/main.go3
-rw-r--r--runsc/sandbox/network.go25
-rw-r--r--scripts/benchmark.sh25
-rwxr-xr-xscripts/common_build.sh12
-rw-r--r--test/iptables/filter_input.go218
-rw-r--r--test/iptables/iptables_test.go36
-rw-r--r--test/iptables/iptables_util.go23
-rw-r--r--test/perf/BUILD114
-rw-r--r--test/perf/linux/BUILD356
-rw-r--r--test/perf/linux/clock_getres_benchmark.cc39
-rw-r--r--test/perf/linux/clock_gettime_benchmark.cc60
-rw-r--r--test/perf/linux/death_benchmark.cc36
-rw-r--r--test/perf/linux/epoll_benchmark.cc99
-rw-r--r--test/perf/linux/fork_benchmark.cc350
-rw-r--r--test/perf/linux/futex_benchmark.cc248
-rw-r--r--test/perf/linux/getdents_benchmark.cc149
-rw-r--r--test/perf/linux/getpid_benchmark.cc37
-rw-r--r--test/perf/linux/gettid_benchmark.cc38
-rw-r--r--test/perf/linux/mapping_benchmark.cc163
-rw-r--r--test/perf/linux/open_benchmark.cc56
-rw-r--r--test/perf/linux/pipe_benchmark.cc66
-rw-r--r--test/perf/linux/randread_benchmark.cc100
-rw-r--r--test/perf/linux/read_benchmark.cc53
-rw-r--r--test/perf/linux/sched_yield_benchmark.cc37
-rw-r--r--test/perf/linux/send_recv_benchmark.cc372
-rw-r--r--test/perf/linux/seqwrite_benchmark.cc66
-rw-r--r--test/perf/linux/signal_benchmark.cc59
-rw-r--r--test/perf/linux/sleep_benchmark.cc60
-rw-r--r--test/perf/linux/stat_benchmark.cc62
-rw-r--r--test/perf/linux/unlink_benchmark.cc66
-rw-r--r--test/perf/linux/write_benchmark.cc52
-rw-r--r--test/runner/BUILD22
-rw-r--r--test/runner/defs.bzl (renamed from test/syscalls/build_defs.bzl)246
-rw-r--r--test/runner/gtest/BUILD (renamed from test/syscalls/gtest/BUILD)0
-rw-r--r--test/runner/gtest/gtest.go167
-rw-r--r--test/runner/runner.go (renamed from test/syscalls/syscall_test_runner.go)33
-rw-r--r--test/syscalls/BUILD25
-rw-r--r--test/syscalls/gtest/gtest.go93
-rw-r--r--test/syscalls/linux/32bit.cc2
-rw-r--r--test/syscalls/linux/BUILD51
-rw-r--r--test/syscalls/linux/alarm.cc3
-rw-r--r--test/syscalls/linux/dev.cc7
-rw-r--r--test/syscalls/linux/exec.cc3
-rw-r--r--test/syscalls/linux/fallocate.cc12
-rw-r--r--test/syscalls/linux/fcntl.cc2
-rw-r--r--test/syscalls/linux/ip_socket_test_util.h16
-rw-r--r--test/syscalls/linux/itimer.cc3
-rw-r--r--test/syscalls/linux/network_namespace.cc121
-rw-r--r--test/syscalls/linux/prctl.cc2
-rw-r--r--test/syscalls/linux/prctl_setuid.cc2
-rw-r--r--test/syscalls/linux/proc.cc2
-rw-r--r--test/syscalls/linux/ptrace.cc2
-rw-r--r--test/syscalls/linux/rseq/uapi.h29
-rw-r--r--test/syscalls/linux/rtsignal.cc3
-rw-r--r--test/syscalls/linux/seccomp.cc2
-rw-r--r--test/syscalls/linux/sigiret.cc3
-rw-r--r--test/syscalls/linux/signalfd.cc2
-rw-r--r--test/syscalls/linux/sigstop.cc2
-rw-r--r--test/syscalls/linux/sigtimedwait.cc3
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc133
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc163
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h55
-rw-r--r--test/syscalls/linux/timers.cc2
-rw-r--r--test/syscalls/linux/tuntap.cc346
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc8
-rw-r--r--test/syscalls/linux/vfork.cc2
-rwxr-xr-xtest/syscalls/syscall_test_runner.sh34
-rw-r--r--test/util/BUILD3
-rw-r--r--test/util/test_main.cc2
-rw-r--r--test/util/test_util.h1
-rw-r--r--test/util/test_util_impl.cc14
-rw-r--r--tools/bazeldefs/defs.bzl3
-rw-r--r--tools/bazeldefs/platforms.bzl17
-rw-r--r--tools/defs.bzl15
-rw-r--r--tools/go_marshal/BUILD5
-rw-r--r--tools/go_marshal/gomarshal/generator.go74
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go317
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go17
-rw-r--r--tools/go_marshal/marshal/marshal.go4
-rw-r--r--tools/go_marshal/test/test.go10
-rwxr-xr-xtools/installers/master.sh17
267 files changed, 14084 insertions, 1871 deletions
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 000000000..a2a260538
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,19 @@
+language: minimal
+sudo: required
+dist: xenial
+cache:
+ directories:
+ - /home/travis/.cache/bazel/
+services:
+ - docker
+matrix:
+ include:
+ - os: linux
+ arch: amd64
+ env: RUNSC_PATH=./bazel-bin/runsc/linux_amd64_pure_stripped/runsc
+ - os: linux
+ arch: arm64
+ env: RUNSC_PATH=./bazel-bin/runsc/linux_arm64_pure_stripped/runsc
+script:
+ - uname -a
+ - make DOCKER_RUN_OPTIONS="" BAZEL_OPTIONS="build runsc:runsc" bazel && $RUNSC_PATH --alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do ls
diff --git a/Dockerfile b/Dockerfile
index 738623023..2bfdfec6c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,8 +1,9 @@
-FROM ubuntu:bionic
+FROM fedora:31
-RUN apt-get update && apt-get install -y curl gnupg2 git python python3 python3-distutils python3-pip
-RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
- curl https://bazel.build/bazel-release.pub.gpg | apt-key add -
-RUN apt-get update && apt-get install -y bazel && apt-get clean
+RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
+
+RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static
+
+RUN pip install pycparser
WORKDIR /gvisor
diff --git a/Makefile b/Makefile
index a73bc0c36..d9531fbd5 100644
--- a/Makefile
+++ b/Makefile
@@ -2,6 +2,9 @@ UID := $(shell id -u ${USER})
GID := $(shell id -g ${USER})
GVISOR_BAZEL_CACHE := $(shell readlink -f ~/.cache/bazel/)
+# The --privileged is required to run tests.
+DOCKER_RUN_OPTIONS ?= --privileged
+
all: runsc
docker-build:
@@ -19,7 +22,7 @@ bazel-server-start: docker-build
-v "$(CURDIR):$(CURDIR)" \
--workdir "$(CURDIR)" \
--tmpfs /tmp:rw,exec \
- --privileged \
+ $(DOCKER_RUN_OPTIONS) \
gvisor-bazel \
sh -c "while :; do sleep 100; done" && \
docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --non-unique --gid $(GID) -d $(HOME) gvisor"
diff --git a/WORKSPACE b/WORKSPACE
index 2827c3a26..a15238a2e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -33,6 +33,20 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
gazelle_dependencies()
+# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories"
+# block below once 1876 is fixed.
+#
+# The com_google_protobuf repository below would trigger downloading a older
+# version of org_golang_x_sys. If putting this repository statment in a place
+# after that of the com_google_protobuf, this statement will not work as
+# expectd to download a new version of org_golang_x_sys.
+go_repository(
+ name = "org_golang_x_sys",
+ importpath = "golang.org/x/sys",
+ sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=",
+ version = "v0.0.0-20191220220014-0732a990476f",
+)
+
# Load C++ rules.
http_archive(
name = "rules_cc",
@@ -257,13 +271,6 @@ go_repository(
)
go_repository(
- name = "org_golang_x_sys",
- importpath = "golang.org/x/sys",
- sum = "h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=",
- version = "v0.0.0-20190215142949-d0b11bdaac8a",
-)
-
-go_repository(
name = "org_golang_x_time",
commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
@@ -330,3 +337,13 @@ http_archive(
"https://github.com/google/googletest/archive/565f1b848215b77c3732bca345fe76a0431d8b34.tar.gz",
],
)
+
+http_archive(
+ name = "com_google_benchmark",
+ sha256 = "3c6a165b6ecc948967a1ead710d4a181d7b0fbcaa183ef7ea84604994966221a",
+ strip_prefix = "benchmark-1.5.0",
+ urls = [
+ "https://mirror.bazel.build/github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ "https://github.com/google/benchmark/archive/v1.5.0.tar.gz",
+ ],
+)
diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py
index 3d32d3dda..5bdc4aa85 100644
--- a/benchmarks/harness/machine.py
+++ b/benchmarks/harness/machine.py
@@ -43,6 +43,8 @@ from benchmarks.harness import machine_mocks
from benchmarks.harness import ssh_connection
from benchmarks.harness import tunnel_dispatcher
+log = logging.getLogger(__name__)
+
class Machine(object):
"""The machine object is the primary object for benchmarks.
@@ -236,9 +238,10 @@ class RemoteMachine(Machine):
archive=archive, dir=harness.REMOTE_INSTALLERS_PATH))
self._has_installers = True
- # Execute the remote installer.
- self.run("sudo {dir}/{file}".format(
- dir=harness.REMOTE_INSTALLERS_PATH, file=installer))
+ # Execute the remote installer.
+ self.run("sudo {dir}/{file}".format(
+ dir=harness.REMOTE_INSTALLERS_PATH, file=installer))
+
if results:
results[index] = True
diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py
index a50e34293..b8c8e42d4 100644
--- a/benchmarks/harness/ssh_connection.py
+++ b/benchmarks/harness/ssh_connection.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""SSHConnection handles the details of SSH connections."""
-
+import logging
import os
import warnings
@@ -24,6 +24,8 @@ from benchmarks import harness
# Get rid of paramiko Cryptography Warnings.
warnings.filterwarnings(action="ignore", module=".*paramiko.*")
+log = logging.getLogger(__name__)
+
def send_one_file(client: paramiko.SSHClient, path: str,
remote_dir: str) -> str:
@@ -94,10 +96,13 @@ class SSHConnection:
The contents of stdout and stderr.
"""
with self._client() as client:
+ log.info("running command: %s", cmd)
_, stdout, stderr = client.exec_command(command=cmd)
- stdout.channel.recv_exit_status()
+ log.info("returned status: %d", stdout.channel.recv_exit_status())
stdout = stdout.read().decode("utf-8")
stderr = stderr.read().decode("utf-8")
+ log.info("stdout: %s", stdout)
+ log.info("stderr: %s", stderr)
return stdout, stderr
def send_workload(self, name: str) -> str:
diff --git a/kokoro/benchmark_tests.cfg b/kokoro/benchmark_tests.cfg
new file mode 100644
index 000000000..c48518a05
--- /dev/null
+++ b/kokoro/benchmark_tests.cfg
@@ -0,0 +1,26 @@
+build_file : 'repo/scripts/benchmark.sh'
+
+
+before_action {
+ fetch_keystore {
+ keystore_resource {
+ keystore_config_id : 73898
+ keyname : 'kokoro-rbe-service-account'
+ },
+ }
+}
+
+env_vars {
+ key : 'PROJECT'
+ value : 'gvisor-kokoro-testing'
+}
+
+env_vars {
+ key : 'ZONE'
+ value : 'us-central1-b'
+}
+
+env_vars {
+ key : 'KOKORO_SERVICE_ACCOUNT'
+ value : '73898_kokoro-rbe-service-account'
+}
diff --git a/kokoro/runtime_tests/go1.12.cfg b/kokoro/runtime_tests/go1.12.cfg
index 164ddc18f..fd4911e88 100644
--- a/kokoro/runtime_tests/go1.12.cfg
+++ b/kokoro/runtime_tests/go1.12.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "go1.12"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/java11.cfg b/kokoro/runtime_tests/java11.cfg
index 4957d4794..7f8611a08 100644
--- a/kokoro/runtime_tests/java11.cfg
+++ b/kokoro/runtime_tests/java11.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "java11"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/nodejs12.4.0.cfg b/kokoro/runtime_tests/nodejs12.4.0.cfg
index 1df343f95..c67ad5567 100644
--- a/kokoro/runtime_tests/nodejs12.4.0.cfg
+++ b/kokoro/runtime_tests/nodejs12.4.0.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "nodejs12.4.0"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/php7.3.6.cfg b/kokoro/runtime_tests/php7.3.6.cfg
index 8e3667125..f266c5e26 100644
--- a/kokoro/runtime_tests/php7.3.6.cfg
+++ b/kokoro/runtime_tests/php7.3.6.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "php7.3.6"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/python3.7.3.cfg b/kokoro/runtime_tests/python3.7.3.cfg
index 0ca70d5bb..574add152 100644
--- a/kokoro/runtime_tests/python3.7.3.cfg
+++ b/kokoro/runtime_tests/python3.7.3.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "python3.7.3"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index a89f34d4b..322d1ccc4 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -30,6 +30,7 @@ go_library(
"futex.go",
"inotify.go",
"ioctl.go",
+ "ioctl_tun.go",
"ip.go",
"ipc.go",
"limits.go",
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
index 6e4de69da..1121a1a92 100644
--- a/pkg/abi/linux/epoll.go
+++ b/pkg/abi/linux/epoll.go
@@ -14,6 +14,10 @@
package linux
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
// Event masks.
const (
EPOLLIN = 0x1
@@ -53,3 +57,6 @@ const (
EPOLL_CTL_DEL = 0x2
EPOLL_CTL_MOD = 0x3
)
+
+// SizeOfEpollEvent is the size of EpollEvent struct.
+var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go
index 57041491c..34ff18009 100644
--- a/pkg/abi/linux/epoll_amd64.go
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -15,6 +15,8 @@
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event::data a __u64. We represent it as
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
index 62ef5821e..f86c35329 100644
--- a/pkg/abi/linux/epoll_arm64.go
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -15,6 +15,8 @@
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index c3ab15a4f..e229ac21c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -241,6 +241,8 @@ const (
)
// Statx represents struct statx.
+//
+// +marshal
type Statx struct {
Mask uint32
Blksize uint32
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 2c652baa2..158d2db5b 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -38,6 +38,8 @@ const (
)
// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+//
+// +marshal
type Statfs struct {
// Type is one of the filesystem magic values, defined above.
Type uint64
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 0e18db9ef..2062e6a4b 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -72,3 +72,29 @@ const (
SIOCGMIIPHY = 0x8947
SIOCGMIIREG = 0x8948
)
+
+// ioctl(2) directions. Used to calculate requests number.
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NONE = 0
+ _IOC_WRITE = 1
+ _IOC_READ = 2
+)
+
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NRBITS = 8
+ _IOC_TYPEBITS = 8
+ _IOC_SIZEBITS = 14
+ _IOC_DIRBITS = 2
+
+ _IOC_NRSHIFT = 0
+ _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
+ _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
+ _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
+)
+
+// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
+func IOC(dir, typ, nr, size uint32) uint32 {
+ return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
+}
diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/abi/linux/ioctl_tun.go
index 876783e78..c59c9c136 100644
--- a/pkg/usermem/usermem_unsafe.go
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package linux
-import (
- "unsafe"
+// ioctl(2) request numbers from linux/if_tun.h
+var (
+ TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4)
+ TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4)
)
-// stringFromImmutableBytes is equivalent to string(bs), except that it never
-// copies even if escape analysis can't prove that bs does not escape. This is
-// only valid if bs is never mutated after stringFromImmutableBytes returns.
-func stringFromImmutableBytes(bs []byte) string {
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
-}
+// Flags from net/if_tun.h
+const (
+ IFF_TUN = 0x0001
+ IFF_TAP = 0x0002
+ IFF_NO_PI = 0x1000
+ IFF_NOFILTER = 0x1000
+)
diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go
index 2179ac995..314c318b6 100644
--- a/pkg/abi/linux/netfilter.go
+++ b/pkg/abi/linux/netfilter.go
@@ -225,11 +225,14 @@ type XTEntryTarget struct {
// SizeOfXTEntryTarget is the size of an XTEntryTarget.
const SizeOfXTEntryTarget = 32
-// XTStandardTarget is a builtin target, one of ACCEPT, DROP, JUMP, QUEUE, or
-// RETURN. It corresponds to struct xt_standard_target in
+// XTStandardTarget is a built-in target, one of ACCEPT, DROP, JUMP, QUEUE,
+// RETURN, or jump. It corresponds to struct xt_standard_target in
// include/uapi/linux/netfilter/x_tables.h.
type XTStandardTarget struct {
- Target XTEntryTarget
+ Target XTEntryTarget
+ // A positive verdict indicates a jump, and is the offset from the
+ // start of the table to jump to. A negative value means one of the
+ // other built-in targets.
Verdict int32
_ [4]byte
}
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index c69b04ea9..1c330e763 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -115,6 +115,8 @@ const (
)
// SignalSet is a signal mask with a bit corresponding to each signal.
+//
+// +marshal
type SignalSet uint64
// SignalSetSize is the size in bytes of a SignalSet.
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index e562b46d9..e6860ed49 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -157,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec {
const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
+//
+// +marshal
type Timeval struct {
Sec int64
Usec int64
@@ -230,6 +232,8 @@ type Tms struct {
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
+//
+// +marshal
type StatxTimestamp struct {
Sec int64
Nsec uint32
@@ -258,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
}
// Utime represents struct utimbuf used by utimes(2).
+//
+// +marshal
type Utime struct {
Actime int64
Modtime int64
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
index a3b6406fa..99180b208 100644
--- a/pkg/abi/linux/xattr.go
+++ b/pkg/abi/linux/xattr.go
@@ -18,6 +18,7 @@ package linux
const (
XATTR_NAME_MAX = 255
XATTR_SIZE_MAX = 65536
+ XATTR_LIST_MAX = 65536
XATTR_CREATE = 1
XATTR_REPLACE = 2
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index 3948074ba..1a30f6967 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -5,10 +5,10 @@ package(licenses = ["notice"])
go_library(
name = "atomicbitops",
srcs = [
- "atomic_bitops.go",
- "atomic_bitops_amd64.s",
- "atomic_bitops_arm64.s",
- "atomic_bitops_common.go",
+ "atomicbitops.go",
+ "atomicbitops_amd64.s",
+ "atomicbitops_arm64.s",
+ "atomicbitops_noasm.go",
],
visibility = ["//:sandbox"],
)
@@ -16,7 +16,7 @@ go_library(
go_test(
name = "atomicbitops_test",
size = "small",
- srcs = ["atomic_bitops_test.go"],
+ srcs = ["atomicbitops_test.go"],
library = ":atomicbitops",
deps = ["//pkg/sync"],
)
diff --git a/pkg/atomicbitops/atomic_bitops.go b/pkg/atomicbitops/atomicbitops.go
index fcc41a9ea..1be081719 100644
--- a/pkg/atomicbitops/atomic_bitops.go
+++ b/pkg/atomicbitops/atomicbitops.go
@@ -14,47 +14,34 @@
// +build amd64 arm64
-// Package atomicbitops provides basic bitwise operations in an atomic way.
-// The implementation on amd64 leverages the LOCK prefix directly instead of
-// relying on the generic cas primitives, and the arm64 leverages the LDAXR
-// and STLXR pair primitives.
+// Package atomicbitops provides extensions to the sync/atomic package.
//
-// WARNING: the bitwise ops provided in this package doesn't imply any memory
-// ordering. Using them to construct locks must employ proper memory barriers.
+// All read-modify-write operations implemented by this package have
+// acquire-release memory ordering (like sync/atomic).
package atomicbitops
-// AndUint32 atomically applies bitwise and operation to *addr with val.
+// AndUint32 atomically applies bitwise AND operation to *addr with val.
func AndUint32(addr *uint32, val uint32)
-// OrUint32 atomically applies bitwise or operation to *addr with val.
+// OrUint32 atomically applies bitwise OR operation to *addr with val.
func OrUint32(addr *uint32, val uint32)
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
+// XorUint32 atomically applies bitwise XOR operation to *addr with val.
func XorUint32(addr *uint32, val uint32)
// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) uint32
-// AndUint64 atomically applies bitwise and operation to *addr with val.
+// AndUint64 atomically applies bitwise AND operation to *addr with val.
func AndUint64(addr *uint64, val uint64)
-// OrUint64 atomically applies bitwise or operation to *addr with val.
+// OrUint64 atomically applies bitwise OR operation to *addr with val.
func OrUint64(addr *uint64, val uint64)
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
+// XorUint64 atomically applies bitwise XOR operation to *addr with val.
func XorUint64(addr *uint64, val uint64)
// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) uint64
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool
diff --git a/pkg/atomicbitops/atomic_bitops_amd64.s b/pkg/atomicbitops/atomicbitops_amd64.s
index db0972001..54c887ee5 100644
--- a/pkg/atomicbitops/atomic_bitops_amd64.s
+++ b/pkg/atomicbitops/atomicbitops_amd64.s
@@ -75,41 +75,3 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
CMPXCHGQ DX, 0(DI)
MOVQ AX, ret+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- TESTL AX, AX
- JZ fail
- LEAL 1(AX), DX
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB AX, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVQ addr+0(FP), DI
- MOVL 0(DI), AX
-
-retry:
- LEAL -1(AX), DX
- TESTL DX, DX
- JZ fail
- LOCK
- CMPXCHGL DX, 0(DI)
- JNZ retry
-
- SETEQ ret+8(FP)
- RET
-
-fail:
- MOVB DX, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_arm64.s b/pkg/atomicbitops/atomicbitops_arm64.s
index 97f8808c1..5c780851b 100644
--- a/pkg/atomicbitops/atomic_bitops_arm64.s
+++ b/pkg/atomicbitops/atomicbitops_arm64.s
@@ -50,7 +50,6 @@ TEXT ·CompareAndSwapUint32(SB),$0-20
MOVD addr+0(FP), R0
MOVW old+8(FP), R1
MOVW new+12(FP), R2
-
again:
LDAXRW (R0), R3
CMPW R1, R3
@@ -95,7 +94,6 @@ TEXT ·CompareAndSwapUint64(SB),$0-32
MOVD addr+0(FP), R0
MOVD old+8(FP), R1
MOVD new+16(FP), R2
-
again:
LDAXR (R0), R3
CMP R1, R3
@@ -105,35 +103,3 @@ again:
done:
MOVD R3, prev+24(FP)
RET
-
-TEXT ·IncUnlessZeroInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- CBZ R1, fail
- ADDW $1, R1
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
-
-TEXT ·DecUnlessOneInt32(SB),NOSPLIT,$0-9
- MOVD addr+0(FP), R0
-
-again:
- LDAXRW (R0), R1
- SUBSW $1, R1, R1
- BEQ fail
- STLXRW R1, (R0), R2
- CBNZ R2, again
- MOVW $1, R2
- MOVB R2, ret+8(FP)
- RET
-fail:
- MOVB ZR, ret+8(FP)
- RET
diff --git a/pkg/atomicbitops/atomic_bitops_common.go b/pkg/atomicbitops/atomicbitops_noasm.go
index 85163ad62..3b2898256 100644
--- a/pkg/atomicbitops/atomic_bitops_common.go
+++ b/pkg/atomicbitops/atomicbitops_noasm.go
@@ -20,7 +20,6 @@ import (
"sync/atomic"
)
-// AndUint32 atomically applies bitwise and operation to *addr with val.
func AndUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -31,7 +30,6 @@ func AndUint32(addr *uint32, val uint32) {
}
}
-// OrUint32 atomically applies bitwise or operation to *addr with val.
func OrUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -42,7 +40,6 @@ func OrUint32(addr *uint32, val uint32) {
}
}
-// XorUint32 atomically applies bitwise xor operation to *addr with val.
func XorUint32(addr *uint32, val uint32) {
for {
o := atomic.LoadUint32(addr)
@@ -53,8 +50,6 @@ func XorUint32(addr *uint32, val uint32) {
}
}
-// CompareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
-// the value previously stored at addr.
func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
for {
prev = atomic.LoadUint32(addr)
@@ -67,7 +62,6 @@ func CompareAndSwapUint32(addr *uint32, old, new uint32) (prev uint32) {
}
}
-// AndUint64 atomically applies bitwise and operation to *addr with val.
func AndUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -78,7 +72,6 @@ func AndUint64(addr *uint64, val uint64) {
}
}
-// OrUint64 atomically applies bitwise or operation to *addr with val.
func OrUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -89,7 +82,6 @@ func OrUint64(addr *uint64, val uint64) {
}
}
-// XorUint64 atomically applies bitwise xor operation to *addr with val.
func XorUint64(addr *uint64, val uint64) {
for {
o := atomic.LoadUint64(addr)
@@ -100,8 +92,6 @@ func XorUint64(addr *uint64, val uint64) {
}
}
-// CompareAndSwapUint64 is like sync/atomic.CompareAndSwapUint64, but returns
-// the value previously stored at addr.
func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
for {
prev = atomic.LoadUint64(addr)
@@ -113,35 +103,3 @@ func CompareAndSwapUint64(addr *uint64, old, new uint64) (prev uint64) {
}
}
}
-
-// IncUnlessZeroInt32 increments the value stored at the given address and
-// returns true; unless the value stored in the pointer is zero, in which case
-// it is left unmodified and false is returned.
-func IncUnlessZeroInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 0 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v+1) {
- return true
- }
- }
-}
-
-// DecUnlessOneInt32 decrements the value stored at the given address and
-// returns true; unless the value stored in the pointer is 1, in which case it
-// is left unmodified and false is returned.
-func DecUnlessOneInt32(addr *int32) bool {
- for {
- v := atomic.LoadInt32(addr)
- if v == 1 {
- return false
- }
-
- if atomic.CompareAndSwapInt32(addr, v, v-1) {
- return true
- }
- }
-}
diff --git a/pkg/atomicbitops/atomic_bitops_test.go b/pkg/atomicbitops/atomicbitops_test.go
index 9466d3e23..73af71bb4 100644
--- a/pkg/atomicbitops/atomic_bitops_test.go
+++ b/pkg/atomicbitops/atomicbitops_test.go
@@ -196,67 +196,3 @@ func TestCompareAndSwapUint64(t *testing.T) {
}
}
}
-
-func TestIncUnlessZeroInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: 0,
- ret: false,
- },
- {
- initial: 1,
- final: 2,
- ret: true,
- },
- {
- initial: 2,
- final: 3,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := IncUnlessZeroInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
-
-func TestDecUnlessOneInt32(t *testing.T) {
- for _, test := range []struct {
- initial int32
- final int32
- ret bool
- }{
- {
- initial: 0,
- final: -1,
- ret: true,
- },
- {
- initial: 1,
- final: 1,
- ret: false,
- },
- {
- initial: 2,
- final: 1,
- ret: true,
- },
- } {
- val := test.initial
- if got, want := DecUnlessOneInt32(&val), test.ret; got != want {
- t.Errorf("For initial value of %d: incorrect return value: got %v, wanted %v", test.initial, got, want)
- }
- if got, want := val, test.final; got != want {
- t.Errorf("For initial value of %d: incorrect final value: got %d, wanted %d", test.initial, got, want)
- }
- }
-}
diff --git a/pkg/cpuid/cpuid_x86.go b/pkg/cpuid/cpuid_x86.go
index 333ca0a04..a0bc55ea1 100644
--- a/pkg/cpuid/cpuid_x86.go
+++ b/pkg/cpuid/cpuid_x86.go
@@ -725,6 +725,18 @@ func vendorIDFromRegs(bx, cx, dx uint32) string {
return string(bytes)
}
+var maxXsaveSize = func() uint32 {
+ // Leaf 0 of xsaveinfo function returns the size for currently
+ // enabled xsave features in ebx, the maximum size if all valid
+ // features are saved with xsave in ecx, and valid XCR0 bits in
+ // edx:eax.
+ //
+ // If xSaveInfo isn't supported, cpuid will not fault but will
+ // return bogus values.
+ _, _, maxXsaveSize, _ := HostID(uint32(xSaveInfo), 0)
+ return maxXsaveSize
+}()
+
// ExtendedStateSize returns the number of bytes needed to save the "extended
// state" for this processor and the boundary it must be aligned to. Extended
// state includes floating point registers, and other cpu state that's not
@@ -736,12 +748,7 @@ func vendorIDFromRegs(bx, cx, dx uint32) string {
// about 2.5K worst case, with avx512).
func (fs *FeatureSet) ExtendedStateSize() (size, align uint) {
if fs.UseXsave() {
- // Leaf 0 of xsaveinfo function returns the size for currently
- // enabled xsave features in ebx, the maximum size if all valid
- // features are saved with xsave in ecx, and valid XCR0 bits in
- // edx:eax.
- _, _, maxSize, _ := HostID(uint32(xSaveInfo), 0)
- return uint(maxSize), 64
+ return uint(maxXsaveSize), 64
}
// If we don't support xsave, we fall back to fxsave, which requires
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index ee84471b2..67dd1e225 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -8,9 +8,11 @@ go_library(
name = "fspath",
srcs = [
"builder.go",
- "builder_unsafe.go",
"fspath.go",
],
+ deps = [
+ "//pkg/gohacks",
+ ],
)
go_test(
diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go
index 7ddb36826..6318d3874 100644
--- a/pkg/fspath/builder.go
+++ b/pkg/fspath/builder.go
@@ -16,6 +16,8 @@ package fspath
import (
"fmt"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
)
// Builder is similar to strings.Builder, but is used to produce pathnames
@@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) {
copy(b.buf[b.start:], b.buf[oldStart:])
copy(b.buf[len(b.buf)-len(str):], str)
}
+
+// String returns the accumulated string. No other methods should be called
+// after String.
+func (b *Builder) String() string {
+ return gohacks.StringFromImmutableBytes(b.buf[b.start:])
+}
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
index 9fb3fee24..4c983d5fd 100644
--- a/pkg/fspath/fspath.go
+++ b/pkg/fspath/fspath.go
@@ -67,7 +67,8 @@ func Parse(pathname string) Path {
// Path contains the information contained in a pathname string.
//
-// Path is copyable by value.
+// Path is copyable by value. The zero value for Path is equivalent to
+// fspath.Parse(""), i.e. the empty path.
type Path struct {
// Begin is an iterator to the first path component in the relative part of
// the path.
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
new file mode 100644
index 000000000..798a65eca
--- /dev/null
+++ b/pkg/gohacks/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gohacks",
+ srcs = [
+ "gohacks_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
new file mode 100644
index 000000000..aad675172
--- /dev/null
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gohacks contains utilities for subverting the Go compiler.
+package gohacks
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Noescape hides a pointer from escape analysis. Noescape is the identity
+// function but escape analysis doesn't think the output depends on the input.
+// Noescape is inlined and currently compiles down to zero instructions.
+// USE CAREFULLY!
+//
+// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().)
+//
+//go:nosplit
+func Noescape(p unsafe.Pointer) unsafe.Pointer {
+ x := uintptr(p)
+ return unsafe.Pointer(x ^ 0)
+}
+
+// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the
+// same memory backing s instead of making a heap-allocated copy. This is only
+// valid if the returned slice is never mutated.
+func ImmutableBytesFromString(s string) []byte {
+ shdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ var bs []byte
+ bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ bshdr.Data = shdr.Data
+ bshdr.Len = shdr.Len
+ bshdr.Cap = shdr.Len
+ return bs
+}
+
+// StringFromImmutableBytes is equivalent to string(bs), except that it uses
+// the same memory backing bs instead of making a heap-allocated copy. This is
+// only valid if bs is never mutated after StringFromImmutableBytes returns.
+func StringFromImmutableBytes(bs []byte) string {
+ // This is cheaper than messing with reflect.StringHeader and
+ // reflect.SliceHeader, which as of this writing produces many dead stores
+ // of zeroes. Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index cab5fae55..b4f7bb5a4 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -46,7 +46,7 @@ var pid = os.Getpid()
// line The line number
// msg The user-supplied message
//
-func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
prefix := byte('?')
switch level {
@@ -64,9 +64,7 @@ func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, ar
microsecond := int(timestamp.Nanosecond() / 1000)
// 0 = this frame.
- // 1 = Debugf, etc.
- // 2 = Caller.
- _, file, line, ok := runtime.Caller(2)
+ _, file, line, ok := runtime.Caller(depth + 1)
if ok {
// Trim any directory path from the file.
slash := strings.LastIndexByte(file, byte('/'))
diff --git a/pkg/log/json.go b/pkg/log/json.go
index a278c8fc8..0943db1cc 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -62,7 +62,7 @@ type JSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := jsonLog{
Msg: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index cee6eb514..6c6fc8b6f 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -33,7 +33,7 @@ type K8sJSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e *K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 5056f17e6..a794da1aa 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -79,7 +79,7 @@ func (l Level) String() string {
type Emitter interface {
// Emit emits the given log statement. This allows for control over the
// timestamp used for logging.
- Emit(level Level, timestamp time.Time, format string, v ...interface{})
+ Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{})
}
// Writer writes the output to the given writer.
@@ -142,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) {
}
// Emit emits the message.
-func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) {
fmt.Fprintf(l, format, args...)
}
@@ -150,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i
type MultiEmitter []Emitter
// Emit emits to all emitters.
-func (m *MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) {
for _, e := range *m {
- e.Emit(level, timestamp, format, v...)
+ e.Emit(1+depth, level, timestamp, format, v...)
}
}
@@ -167,7 +167,7 @@ type TestEmitter struct {
}
// Emit emits to the TestLogger.
-func (t *TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
t.Logf(format, v...)
}
@@ -198,22 +198,37 @@ type BasicLogger struct {
// Debugf implements logger.Debugf.
func (l *BasicLogger) Debugf(format string, v ...interface{}) {
- if l.IsLogging(Debug) {
- l.Emit(Debug, time.Now(), format, v...)
- }
+ l.DebugfAtDepth(1, format, v...)
}
// Infof implements logger.Infof.
func (l *BasicLogger) Infof(format string, v ...interface{}) {
- if l.IsLogging(Info) {
- l.Emit(Info, time.Now(), format, v...)
- }
+ l.InfofAtDepth(1, format, v...)
}
// Warningf implements logger.Warningf.
func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ l.WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs at a specific depth.
+func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(1+depth, Debug, time.Now(), format, v...)
+ }
+}
+
+// InfofAtDepth logs at a specific depth.
+func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(1+depth, Info, time.Now(), format, v...)
+ }
+}
+
+// WarningfAtDepth logs at a specific depth.
+func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) {
if l.IsLogging(Warning) {
- l.Emit(Warning, time.Now(), format, v...)
+ l.Emit(1+depth, Warning, time.Now(), format, v...)
}
}
@@ -257,17 +272,32 @@ func SetLevel(newLevel Level) {
// Debugf logs to the global logger.
func Debugf(format string, v ...interface{}) {
- Log().Debugf(format, v...)
+ Log().DebugfAtDepth(1, format, v...)
}
// Infof logs to the global logger.
func Infof(format string, v ...interface{}) {
- Log().Infof(format, v...)
+ Log().InfofAtDepth(1, format, v...)
}
// Warningf logs to the global logger.
func Warningf(format string, v ...interface{}) {
- Log().Warningf(format, v...)
+ Log().WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs to the global logger.
+func DebugfAtDepth(depth int, format string, v ...interface{}) {
+ Log().DebugfAtDepth(1+depth, format, v...)
+}
+
+// InfofAtDepth logs to the global logger.
+func InfofAtDepth(depth int, format string, v ...interface{}) {
+ Log().InfofAtDepth(1+depth, format, v...)
+}
+
+// WarningfAtDepth logs to the global logger.
+func WarningfAtDepth(depth int, format string, v ...interface{}) {
+ Log().WarningfAtDepth(1+depth, format, v...)
}
// defaultStackSize is the default buffer size to allocate for stack traces.
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 3b6987665..d7794db63 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -34,6 +34,9 @@ const (
SyscallWidth = 4
)
+// ARMTrapFlag is the mask for the trap flag.
+const ARMTrapFlag = uint64(1) << 21
+
// aarch64FPState is aarch64 floating point state.
type aarch64FPState []byte
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 4c4b7d5cc..9b6bb26d0 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -9,6 +9,7 @@ go_library(
"device.go",
"fs.go",
"full.go",
+ "net_tun.go",
"null.go",
"random.go",
"tty.go",
@@ -19,15 +20,19 @@ go_library(
"//pkg/context",
"//pkg/rand",
"//pkg/safemem",
+ "//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
+ "//pkg/sentry/socket/netstack",
"//pkg/syserror",
+ "//pkg/tcpip/link/tun",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 35bd23991..7e66c29b0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
})
}
-func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555))
+func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
@@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
// the devpts to be mounted over.
- "pts": newDirectory(ctx, msrc),
+ "pts": newDirectory(ctx, nil, msrc),
// Similarly, applications expect a ptmx device at /dev/ptmx
// connected to the terminals provided by /dev/pts/. Rather
// than creating a device directly (which requires a hairy
@@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+
+ "net": newDirectory(ctx, map[string]*fs.Inode{
+ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
+ }, msrc),
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
new file mode 100644
index 000000000..755644488
--- /dev/null
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// +stateify savable
+type netTunInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*netTunInodeOperations)(nil)
+
+func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations {
+ return &netTunInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
+}
+
+// +stateify savable
+type netTunFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ device tun.Device
+}
+
+var _ fs.FileOperations = (*netTunFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (fops *netTunFileOperations) Release() {
+ fops.device.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fops.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
+ flags := fops.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Write implements fs.FileOperations.Write.
+func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fops.device.Write(data)
+}
+
+// Read implements fs.FileOperations.Read.
+func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := fops.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fops.device.Readiness(mask)
+}
+
+// EventRegister implements watier.Waitable.EventRegister.
+func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fops.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements watier.Waitable.EventUnregister.
+func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ fops.device.EventUnregister(e)
+}
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
index e672a438c..a3d10770b 100644
--- a/pkg/sentry/fs/mount_test.go
+++ b/pkg/sentry/fs/mount_test.go
@@ -36,11 +36,12 @@ func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
gotPaths := make(map[string]struct{}, len(got))
gotStr := make([]string, len(got))
for i, g := range got {
- groot := g.Root()
- name, _ := groot.FullName(root)
- groot.DecRef()
- gotStr[i] = name
- gotPaths[name] = struct{}{}
+ if groot := g.Root(); groot != nil {
+ name, _ := groot.FullName(root)
+ groot.DecRef()
+ gotStr[i] = name
+ gotPaths[name] = struct{}{}
+ }
}
if len(got) != len(want) {
return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want)
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 574a2cc91..c7981f66e 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -100,10 +100,14 @@ func newUndoMount(d *Dirent) *Mount {
}
}
-// Root returns the root dirent of this mount. Callers must call DecRef on the
-// returned dirent.
+// Root returns the root dirent of this mount.
+//
+// This may return nil if the mount has already been free. Callers must handle this
+// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent.
func (m *Mount) Root() *Dirent {
- m.root.IncRef()
+ if !m.root.TryIncRef() {
+ return nil
+ }
return m.root
}
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index c10888100..94deb553b 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -60,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
})
for _, m := range ms {
mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
mountPath, desc := mroot.FullName(rootDir)
mroot.DecRef()
if !desc {
// MountSources that are not descendants of the chroot jail are ignored.
continue
}
-
fn(mountPath, m)
}
}
@@ -91,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
var buf bytes.Buffer
forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef()
+
// Format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
@@ -107,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
// (3) Major:Minor device ID. We don't have a superblock, so we
// just use the root inode device number.
- mroot := m.Root()
- defer mroot.DecRef()
-
sa := mroot.Inode.StableAttr
fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
@@ -207,6 +212,9 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
//
// The "needs dump"and fsck flags are always 0, which is allowed.
root := m.Root()
+ if root == nil {
+ return // No longer valid.
+ }
defer root.DecRef()
flags := root.Inode.MountSource.Flags
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 6f2775344..95d5817ff 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -43,7 +43,10 @@ import (
// newNet creates a new proc net entry.
func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
"snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 0772d4ae4..d4c4b533d 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
"core": p.newSysNetCore(ctx, msrc, s),
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index e657c39bc..6aa17bfc1 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -117,15 +117,19 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
// default anyways.
//
// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission.
-func (l *vfsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
+func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
+ path := fspath.Parse(pathname)
pop := &vfs.PathOperation{
Root: l.root,
- Start: l.root,
- Path: fspath.Parse(path),
+ Start: l.workingDir,
+ Path: path,
FollowFinalSymlink: resolveFinal,
}
+ if path.Absolute {
+ pop.Start = l.root
+ }
fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts)
if err != nil {
return nil, err
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index ce08a7d53..10c08fa90 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -73,9 +73,9 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames
"meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
"mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
"net": newNetDir(root, inoGen, k),
- "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}),
"uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
- "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}),
}
inode := &tasksInode{
diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go
index 608fec017..d4e1812d8 100644
--- a/pkg/sentry/fsimpl/proc/tasks_net.go
+++ b/pkg/sentry/fsimpl/proc/tasks_net.go
@@ -39,7 +39,10 @@ import (
func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index c7ce74883..3d5dc463c 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]*kernfs.Dentry{
"ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
"tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}),
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index d0be32e72..488478e29 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -128,6 +128,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns
ThreadGroup: tc,
TaskContext: &kernel.TaskContext{Name: name},
Credentials: auth.CredentialsFromContext(ctx),
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 7f7b791c4..e1b551422 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -16,7 +16,6 @@ package tmpfs
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -347,10 +346,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, err
}
if opts.Flags&linux.O_TRUNC != 0 {
- impl.mu.Lock()
- impl.data.Truncate(0, impl.memFile)
- atomic.StoreUint64(&impl.size, 0)
- impl.mu.Unlock()
+ if _, err := impl.truncate(0); err != nil {
+ return nil, err
+ }
}
return &fd.vfsfd, nil
case *directory:
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index dab346a41..711442424 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -15,6 +15,7 @@
package tmpfs
import (
+ "fmt"
"io"
"math"
"sync/atomic"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -34,25 +36,53 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// regularFile is a regular (=S_IFREG) tmpfs file.
type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
memFile *pgalloc.MemoryFile
- // mu protects the fields below.
- mu sync.RWMutex
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // Protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ // dataMu protects the fields below.
+ dataMu sync.RWMutex
// data maps offsets into the file to offsets into memFile that store
// the file's data.
+ //
+ // Protected by dataMu.
data fsutil.FileRangeSet
- // size is the size of data, but accessed using atomic memory
- // operations to avoid locking in inode.stat().
- size uint64
-
// seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
seals uint32
+
+ // size is the size of data.
+ //
+ // Protected by both dataMu and inode.mu; reading it requires holding
+ // either mutex, while writing requires holding both AND using atomics.
+ // Readers that do not require consistency (like Stat) may read the
+ // value atomically without holding either lock.
+ size uint64
}
func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
@@ -66,39 +96,170 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
// truncate grows or shrinks the file to the given size. It returns true if the
// file size was updated.
-func (rf *regularFile) truncate(size uint64) (bool, error) {
- rf.mu.Lock()
- defer rf.mu.Unlock()
+func (rf *regularFile) truncate(newSize uint64) (bool, error) {
+ rf.inode.mu.Lock()
+ defer rf.inode.mu.Unlock()
+ return rf.truncateLocked(newSize)
+}
- if size == rf.size {
+// Preconditions: rf.inode.mu must be held.
+func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
+ oldSize := rf.size
+ if newSize == oldSize {
// Nothing to do.
return false, nil
}
- if size > rf.size {
- // Growing the file.
+ // Need to hold inode.mu and dataMu while modifying size.
+ rf.dataMu.Lock()
+ if newSize > oldSize {
+ // Can we grow the file?
if rf.seals&linux.F_SEAL_GROW != 0 {
- // Seal does not allow growth.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- rf.size = size
+ // We only need to update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
return true, nil
}
- // Shrinking the file
+ // We are shrinking the file. First check if this is allowed.
if rf.seals&linux.F_SEAL_SHRINK != 0 {
- // Seal does not allow shrink.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
- // mappings.
+ // Update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+
+ // Invalidate past translations of truncated pages.
+ oldpgend := fs.OffsetPageEnd(int64(oldSize))
+ newpgend := fs.OffsetPageEnd(int64(newSize))
+ if newpgend < oldpgend {
+ rf.mapsMu.Lock()
+ rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ rf.mapsMu.Unlock()
+ }
- rf.data.Truncate(size, rf.memFile)
- rf.size = size
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ rf.dataMu.Lock()
+ rf.data.Truncate(newSize, rf.memFile)
+ rf.dataMu.Unlock()
return true, nil
}
+// AddMapping implements memmap.Mappable.AddMapping.
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if rf.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ rf.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+
+ rf.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return rf.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ rf.dataMu.Lock()
+ defer rf.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(int64(rf.size))
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: rf.memFile,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (*regularFile) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
type regularFileFD struct {
fileDescription
@@ -152,8 +313,10 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Overflow.
return 0, syserror.EFBIG
}
+ f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
+ f.inode.mu.Unlock()
putRegularFileReadWriter(rw)
return n, err
}
@@ -215,6 +378,12 @@ func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng
return nil
}
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ file := fd.inode().impl.(*regularFile)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
+}
+
// regularFileReadWriter implements safemem.Reader and Safemem.Writer.
type regularFileReadWriter struct {
file *regularFile
@@ -244,14 +413,15 @@ func putRegularFileReadWriter(rw *regularFileReadWriter) {
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- rw.file.mu.RLock()
+ rw.file.dataMu.RLock()
+ defer rw.file.dataMu.RUnlock()
+ size := rw.file.size
// Compute the range to read (limited by file size and overflow-checked).
- if rw.off >= rw.file.size {
- rw.file.mu.RUnlock()
+ if rw.off >= size {
return 0, io.EOF
}
- end := rw.file.size
+ end := size
if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
end = rend
}
@@ -265,7 +435,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
// Get internal mappings.
ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -275,7 +444,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -291,7 +459,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -299,13 +466,16 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
}
}
- rw.file.mu.RUnlock()
return done, nil
}
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: inode.mu must be held.
func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- rw.file.mu.Lock()
+ // Hold dataMu so we can modify size.
+ rw.file.dataMu.Lock()
+ defer rw.file.dataMu.Unlock()
// Compute the range to write (overflow-checked).
end := rw.off + srcs.NumBytes()
@@ -316,7 +486,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
// Check if seals prevent either file growth or all writes.
switch {
case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
- rw.file.mu.Unlock()
return 0, syserror.EPERM
case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
// When growth is sealed, Linux effectively allows writes which would
@@ -338,7 +507,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
}
if end <= rw.off {
// Truncation would result in no data being written.
- rw.file.mu.Unlock()
return 0, syserror.EPERM
}
}
@@ -395,9 +563,8 @@ exitLoop:
// If the write ends beyond the file's previous size, it causes the
// file to grow.
if rw.off > rw.file.size {
- atomic.StoreUint64(&rw.file.size, rw.off)
+ rw.file.size = rw.off
}
- rw.file.mu.Unlock()
return done, retErr
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index c5bb17562..521206305 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -18,9 +18,10 @@
// Lock order:
//
// filesystem.mu
-// regularFileFD.offMu
-// regularFile.mu
// inode.mu
+// regularFileFD.offMu
+// regularFile.mapsMu
+// regularFile.dataMu
package tmpfs
import (
@@ -226,12 +227,15 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
- // This is unnecessary; it's mostly to simulate what tmpfs would do.
if regFile, ok := i.impl.(*regularFile); ok {
- regFile.mu.Lock()
+ // Hold inode.mu and regFile.dataMu while mutating
+ // size.
+ i.mu.Lock()
+ regFile.dataMu.Lock()
regFile.data.DropAll(regFile.memFile)
atomic.StoreUint64(&regFile.size, 0)
- regFile.mu.Unlock()
+ regFile.dataMu.Unlock()
+ i.mu.Unlock()
}
} else if refs < 0 {
panic("tmpfs.inode.decRef() called without holding a reference")
@@ -320,7 +324,7 @@ func (i *inode) setStat(stat linux.Statx) error {
if mask&linux.STATX_SIZE != 0 {
switch impl := i.impl.(type) {
case *regularFile:
- updated, err := impl.truncate(stat.Size)
+ updated, err := impl.truncateLocked(stat.Size)
if err != nil {
return err
}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 334432abf..07bf39fed 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -10,6 +10,7 @@ go_library(
srcs = [
"context.go",
"inet.go",
+ "namespace.go",
"test_stack.go",
],
deps = [
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
new file mode 100644
index 000000000..c16667e7f
--- /dev/null
+++ b/pkg/sentry/inet/namespace.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// Namespace represents a network namespace. See network_namespaces(7).
+//
+// +stateify savable
+type Namespace struct {
+ // stack is the network stack implementation of this network namespace.
+ stack Stack `state:"nosave"`
+
+ // creator allows kernel to create new network stack for network namespaces.
+ // If nil, no networking will function if network is namespaced.
+ creator NetworkStackCreator
+
+ // isRoot indicates whether this is the root network namespace.
+ isRoot bool
+}
+
+// NewRootNamespace creates the root network namespace, with creator
+// allowing new network namespaces to be created. If creator is nil, no
+// networking will function if the network is namespaced.
+func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace {
+ return &Namespace{
+ stack: stack,
+ creator: creator,
+ isRoot: true,
+ }
+}
+
+// NewNamespace creates a new network namespace from the root.
+func NewNamespace(root *Namespace) *Namespace {
+ n := &Namespace{
+ creator: root.creator,
+ }
+ n.init()
+ return n
+}
+
+// Stack returns the network stack of n. Stack may return nil if no network
+// stack is configured.
+func (n *Namespace) Stack() Stack {
+ return n.stack
+}
+
+// IsRoot returns whether n is the root network namespace.
+func (n *Namespace) IsRoot() bool {
+ return n.isRoot
+}
+
+// RestoreRootStack restores the root network namespace with stack. This should
+// only be called when restoring kernel.
+func (n *Namespace) RestoreRootStack(stack Stack) {
+ if !n.isRoot {
+ panic("RestoreRootStack can only be called on root network namespace")
+ }
+ if n.stack != nil {
+ panic("RestoreRootStack called after a stack has already been set")
+ }
+ n.stack = stack
+}
+
+func (n *Namespace) init() {
+ // Root network namespace will have stack assigned later.
+ if n.isRoot {
+ return
+ }
+ if n.creator != nil {
+ var err error
+ n.stack, err = n.creator.CreateStack()
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (n *Namespace) afterLoad() {
+ n.init()
+}
+
+// NetworkStackCreator allows new instances of a network stack to be created. It
+// is used by the kernel to create new network namespaces when requested.
+type NetworkStackCreator interface {
+ // CreateStack creates a new network stack for a network namespace.
+ CreateStack() (Stack, error)
+}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 23b88f7a6..58001d56c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -296,6 +296,50 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// the given file description. If it succeeds, it takes a reference on file.
+func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ if minfd < 0 {
+ // Don't accept negative FDs.
+ return -1, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if minfd >= end {
+ return -1, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ fd := minfd
+ if fd < f.next {
+ fd = f.next
+ }
+ for fd < end {
+ if d, _, _ := f.get(fd); d == nil {
+ f.setVFS2(fd, file, flags)
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fd + 1
+ }
+ return fd, nil
+ }
+ fd++
+ }
+ return -1, syscall.EMFILE
+}
+
// NewFDAt sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
@@ -316,9 +360,6 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
return syscall.EBADF
}
- f.mu.Lock()
- defer f.mu.Unlock()
-
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
@@ -327,6 +368,8 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
// Install the entry.
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.setAll(fd, file, fileVFS2, flags)
return nil
}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index 7218aa24e..47f78df9a 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -244,6 +244,28 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
old.DecRef()
}
+// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
+//
+// This is not a valid call after free.
+func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
+ if !vd.Ok() {
+ panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
+ }
+
+ f.mu.Lock()
+
+ if !f.rootVFS2.Ok() {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd))
+ }
+
+ old := f.rootVFS2
+ vd.IncRef()
+ f.rootVFS2 = vd
+ f.mu.Unlock()
+ old.DecRef()
+}
+
// Umask returns the current umask.
func (f *FSContext) Umask() uint {
f.mu.Lock()
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 7da0368f1..8b76750e9 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -111,7 +111,7 @@ type Kernel struct {
timekeeper *Timekeeper
tasks *TaskSet
rootUserNamespace *auth.UserNamespace
- networkStack inet.Stack `state:"nosave"`
+ rootNetworkNamespace *inet.Namespace
applicationCores uint
useHostCores bool
extraAuxv []arch.AuxEntry
@@ -247,6 +247,10 @@ type Kernel struct {
// VFS keeps the filesystem state used across the kernel.
vfs vfs.VirtualFilesystem
+
+ // If set to true, report address space activation waits as if the task is in
+ // external wait so that the watchdog doesn't report the task stuck.
+ SleepForAddressSpaceActivation bool
}
// InitKernelArgs holds arguments to Init.
@@ -260,8 +264,9 @@ type InitKernelArgs struct {
// RootUserNamespace is the root user namespace.
RootUserNamespace *auth.UserNamespace
- // NetworkStack is the TCP/IP network stack. NetworkStack may be nil.
- NetworkStack inet.Stack
+ // RootNetworkNamespace is the root network namespace. If nil, no networking
+ // will be available.
+ RootNetworkNamespace *inet.Namespace
// ApplicationCores is the number of logical CPUs visible to sandboxed
// applications. The set of logical CPU IDs is [0, ApplicationCores); thus
@@ -320,7 +325,10 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.rootUTSNamespace = args.RootUTSNamespace
k.rootIPCNamespace = args.RootIPCNamespace
k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
- k.networkStack = args.NetworkStack
+ k.rootNetworkNamespace = args.RootNetworkNamespace
+ if k.rootNetworkNamespace == nil {
+ k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil)
+ }
k.applicationCores = args.ApplicationCores
if args.UseHostCores {
k.useHostCores = true
@@ -543,8 +551,6 @@ func (ts *TaskSet) unregisterEpollWaiters() {
func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
- k.networkStack = net
-
initAppCores := k.applicationCores
// Load the pre-saved CPUID FeatureSet.
@@ -575,6 +581,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
log.Infof("Kernel load stats: %s", &stats)
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+ // rootNetworkNamespace should be populated after loading the state file.
+ // Restore the root network stack.
+ k.rootNetworkNamespace.RestoreRootStack(net)
+
// Load the memory file's state.
memoryStart := time.Now()
if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
@@ -905,6 +915,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
FSContext: fsContext,
FDTable: args.FDTable,
Credentials: args.Credentials,
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
UTSNamespace: args.UTSNamespace,
IPCNamespace: args.IPCNamespace,
@@ -1255,10 +1266,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
return k.rootAbstractSocketNamespace
}
-// NetworkStack returns the network stack. NetworkStack may return nil if no
-// network stack is available.
-func (k *Kernel) NetworkStack() inet.Stack {
- return k.networkStack
+// RootNetworkNamespace returns the root network namespace, always non-nil.
+func (k *Kernel) RootNetworkNamespace() *inet.Namespace {
+ return k.rootNetworkNamespace
}
// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 18416643b..ded95f532 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() {
}
var cs linux.RSeqCriticalSection
- if _, err := cs.CopyIn(t, critAddr); err != nil {
+ if err := cs.CopyIn(t, critAddr); err != nil {
t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index a3443ff21..2cee2e6ed 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -486,13 +486,10 @@ type Task struct {
numaPolicy int32
numaNodeMask uint64
- // If netns is true, the task is in a non-root network namespace. Network
- // namespaces aren't currently implemented in full; being in a network
- // namespace simply prevents the task from observing any network devices
- // (including loopback) or using abstract socket addresses (see unix(7)).
+ // netns is the task's network namespace. netns is never nil.
//
- // netns is protected by mu. netns is owned by the task goroutine.
- netns bool
+ // netns is protected by mu.
+ netns *inet.Namespace
// If rseqPreempted is true, before the next call to p.Switch(),
// interrupt rseq critical regions as defined by rseqAddr and
@@ -792,6 +789,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error)
return fds[0], nil
}
+// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ return t.fdTable.NewFDVFS2(t, fd, file, flags)
+}
+
// NewFDAt is a convenience wrapper for t.FDTable().NewFDAt.
//
// This automatically passes the task as the context.
@@ -801,6 +807,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error {
return t.fdTable.NewFDAt(t, fd, file, flags)
}
+// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return t.fdTable.NewFDAtVFS2(t, fd, file, flags)
+}
+
// WithMuLocked executes f with t.mu locked.
func (t *Task) WithMuLocked(f func(*Task)) {
t.mu.Lock()
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index ba74b4c1c..78866f280 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -54,8 +55,7 @@ type SharingOptions struct {
NewUserNamespace bool
// If NewNetworkNamespace is true, the task should have an independent
- // network namespace. (Note that network namespaces are not really
- // implemented; see comment on Task.netns for details.)
+ // network namespace.
NewNetworkNamespace bool
// If NewFiles is true, the task should use an independent file descriptor
@@ -199,6 +199,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
ipcns = NewIPCNamespace(userns)
}
+ netns := t.NetworkNamespace()
+ if opts.NewNetworkNamespace {
+ netns = inet.NewNamespace(netns)
+ }
+
// TODO(b/63601033): Implement CLONE_NEWNS.
mntnsVFS2 := t.mountNamespaceVFS2
if mntnsVFS2 != nil {
@@ -268,7 +273,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
FDTable: fdTable,
Credentials: creds,
Niceness: t.Niceness(),
- NetworkNamespaced: t.netns,
+ NetworkNamespace: netns,
AllowedCPUMask: t.CPUMask(),
UTSNamespace: utsns,
IPCNamespace: ipcns,
@@ -283,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
cfg.InheritParent = t
}
- if opts.NewNetworkNamespace {
- cfg.NetworkNamespaced = true
- }
nt, err := t.tg.pidns.owner.NewTask(cfg)
if err != nil {
if opts.NewThreadGroup {
@@ -482,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
t.mu.Unlock()
return syserror.EPERM
}
- t.netns = true
+ t.netns = inet.NewNamespace(t.netns)
}
if opts.NewUTSNamespace {
if !haveCapSysAdmin {
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 2be982684..0158b1788 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -140,7 +140,7 @@ func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*Task
}
// Prepare a new user address space to load into.
- m := mm.NewMemoryManager(k, k)
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
defer m.DecUsers(ctx)
args.MemoryManager = m
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 8f57a34a6..00c425cca 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -220,7 +220,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.mu.Unlock()
t.unstopVforkParent()
// NOTE(b/30316266): All locks must be dropped prior to calling Activate.
- t.MemoryManager().Activate()
+ t.MemoryManager().Activate(t)
t.ptraceExec(oldTID)
return (*runSyscallExit)(nil)
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index 6d737d3e5..eeccaa197 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -32,21 +32,21 @@ const (
// Infof logs an formatted info message by calling log.Infof.
func (t *Task) Infof(fmt string, v ...interface{}) {
if log.IsLogging(log.Info) {
- log.Infof(t.logPrefix.Load().(string)+fmt, v...)
+ log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Warningf logs a warning string by calling log.Warningf.
func (t *Task) Warningf(fmt string, v ...interface{}) {
if log.IsLogging(log.Warning) {
- log.Warningf(t.logPrefix.Load().(string)+fmt, v...)
+ log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Debugf creates a debug string that includes the task ID.
func (t *Task) Debugf(fmt string, v ...interface{}) {
if log.IsLogging(log.Debug) {
- log.Debugf(t.logPrefix.Load().(string)+fmt, v...)
+ log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index 172a31e1d..f7711232c 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -22,14 +22,23 @@ import (
func (t *Task) IsNetworkNamespaced() bool {
t.mu.Lock()
defer t.mu.Unlock()
- return t.netns
+ return !t.netns.IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
// may return nil if no network stack is available.
+//
+// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
+// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- if t.IsNetworkNamespaced() {
- return nil
- }
- return t.k.networkStack
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns.Stack()
+}
+
+// NetworkNamespace returns the network namespace observed by the task.
+func (t *Task) NetworkNamespace() *inet.Namespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index f9236a842..a5035bb7f 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
@@ -65,9 +66,8 @@ type TaskConfig struct {
// Niceness is the niceness of the new task.
Niceness int
- // If NetworkNamespaced is true, the new task should observe a non-root
- // network namespace.
- NetworkNamespaced bool
+ // NetworkNamespace is the network namespace to be used for the new task.
+ NetworkNamespace *inet.Namespace
// AllowedCPUMask contains the cpus that this task can run on.
AllowedCPUMask sched.CPUSet
@@ -133,7 +133,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
allowedCPUMask: cfg.AllowedCPUMask.Copy(),
ioUsage: &usage.IO{},
niceness: cfg.Niceness,
- netns: cfg.NetworkNamespaced,
+ netns: cfg.NetworkNamespace,
utsns: cfg.UTSNamespace,
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 2bf3ce8a8..b02044ad2 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -30,7 +30,7 @@ var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
// Activate ensures that the task has an active address space.
func (t *Task) Activate() {
if mm := t.MemoryManager(); mm != nil {
- if err := mm.Activate(); err != nil {
+ if err := mm.Activate(t); err != nil {
panic("unable to activate mm: " + err.Error())
}
}
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index e58a63deb..0332fc71c 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -18,7 +18,7 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -39,11 +39,18 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
//
// When this MemoryManager is no longer needed by a task, it should call
// Deactivate to release the reference.
-func (mm *MemoryManager) Activate() error {
+func (mm *MemoryManager) Activate(ctx context.Context) error {
// Fast path: the MemoryManager already has an active
// platform.AddressSpace, and we just need to indicate that we need it too.
- if atomicbitops.IncUnlessZeroInt32(&mm.active) {
- return nil
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 0 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active+1) {
+ return nil
+ }
}
for {
@@ -85,16 +92,20 @@ func (mm *MemoryManager) Activate() error {
if as == nil {
// AddressSpace is unavailable, we must wait.
//
- // activeMu must not be held while waiting, as the user
- // of the address space we are waiting on may attempt
- // to take activeMu.
- //
- // Don't call UninterruptibleSleepStart to register the
- // wait to allow the watchdog stuck task to trigger in
- // case a process is starved waiting for the address
- // space.
+ // activeMu must not be held while waiting, as the user of the address
+ // space we are waiting on may attempt to take activeMu.
mm.activeMu.Unlock()
+
+ sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation
+ if sleep {
+ // Mark this task sleeping while waiting for the address space to
+ // prevent the watchdog from reporting it as a stuck task.
+ ctx.UninterruptibleSleepStart(false)
+ }
<-c
+ if sleep {
+ ctx.UninterruptibleSleepFinish(false)
+ }
continue
}
@@ -118,8 +129,15 @@ func (mm *MemoryManager) Activate() error {
func (mm *MemoryManager) Deactivate() {
// Fast path: this is not the last goroutine to deactivate the
// MemoryManager.
- if atomicbitops.DecUnlessOneInt32(&mm.active) {
- return
+ for {
+ active := atomic.LoadInt32(&mm.active)
+ if active == 1 {
+ // Fall back to the slow path.
+ break
+ }
+ if atomic.CompareAndSwapInt32(&mm.active, active, active-1) {
+ return
+ }
}
mm.activeMu.Lock()
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 47b8fbf43..d8a5b9d29 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -18,7 +18,6 @@ import (
"fmt"
"sync/atomic"
- "gvisor.dev/gvisor/pkg/atomicbitops"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/limits"
@@ -29,16 +28,17 @@ import (
)
// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
-func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager {
+func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager {
return &MemoryManager{
- p: p,
- mfp: mfp,
- haveASIO: p.SupportsAddressSpaceIO(),
- privateRefs: &privateRefs{},
- users: 1,
- auxv: arch.Auxv{},
- dumpability: UserDumpable,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ p: p,
+ mfp: mfp,
+ haveASIO: p.SupportsAddressSpaceIO(),
+ privateRefs: &privateRefs{},
+ users: 1,
+ auxv: arch.Auxv{},
+ dumpability: UserDumpable,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: sleepForActivation,
}
}
@@ -80,9 +80,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
envv: mm.envv,
auxv: append(arch.Auxv(nil), mm.auxv...),
// IncRef'd below, once we know that there isn't an error.
- executable: mm.executable,
- dumpability: mm.dumpability,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ executable: mm.executable,
+ dumpability: mm.dumpability,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: mm.sleepForActivation,
}
// Copy vmas.
@@ -229,7 +230,15 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
// IncUsers increments mm's user count and returns true. If the user count is
// already 0, IncUsers does nothing and returns false.
func (mm *MemoryManager) IncUsers() bool {
- return atomicbitops.IncUnlessZeroInt32(&mm.users)
+ for {
+ users := atomic.LoadInt32(&mm.users)
+ if users == 0 {
+ return false
+ }
+ if atomic.CompareAndSwapInt32(&mm.users, users, users+1) {
+ return true
+ }
+ }
}
// DecUsers decrements mm's user count. If the user count reaches 0, all
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 637383c7a..c2195ae11 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -226,6 +226,11 @@ type MemoryManager struct {
// aioManager keeps track of AIOContexts used for async IOs. AIOManager
// must be cloned when CLONE_VM is used.
aioManager aioManager
+
+ // sleepForActivation indicates whether the task should report to be sleeping
+ // before trying to activate the address space. When set to true, delays in
+ // activation are not reported as stuck tasks by the watchdog.
+ sleepForActivation bool
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index edacca741..fdc308542 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -31,7 +31,7 @@ import (
func testMemoryManager(ctx context.Context) *MemoryManager {
p := platform.FromContext(ctx)
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
- mm := NewMemoryManager(p, mfp)
+ mm := NewMemoryManager(p, mfp, false)
mm.layout = arch.MmapLayout{
MinAddr: p.MinUserAddress(),
MaxAddr: p.MaxUserAddress(),
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 8076c7529..f1afc74dc 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -329,10 +329,12 @@ func (m *machine) Destroy() {
}
// Get gets an available vCPU.
+//
+// This will return with the OS thread locked.
func (m *machine) Get() *vCPU {
+ m.mu.RLock()
runtime.LockOSThread()
tid := procid.Current()
- m.mu.RLock()
// Check for an exact match.
if c := m.vCPUs[tid]; c != nil {
@@ -343,8 +345,22 @@ func (m *machine) Get() *vCPU {
// The happy path failed. We now proceed to acquire an exclusive lock
// (because the vCPU map may change), and scan all available vCPUs.
+ // In this case, we first unlock the OS thread. Otherwise, if mu is
+ // not available, the current system thread will be parked and a new
+ // system thread spawned. We avoid this situation by simply refreshing
+ // tid after relocking the system thread.
m.mu.RUnlock()
+ runtime.UnlockOSThread()
m.mu.Lock()
+ runtime.LockOSThread()
+ tid = procid.Current()
+
+ // Recheck for an exact match.
+ if c := m.vCPUs[tid]; c != nil {
+ c.lock()
+ m.mu.Unlock()
+ return c
+ }
for {
// Scan for an available vCPU.
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4667373d2..8834a1e1a 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -329,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
}
// PackTClass packs an IPV6_TCLASS socket control message.
-func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IPV6,
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index c91ec7494..7cd2ce55b 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"extensions.go",
"netfilter.go",
+ "targets.go",
"tcp_matcher.go",
"udp_matcher.go",
],
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 257cb485b..f68a2260d 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -248,13 +248,15 @@ func marshalTarget(target iptables.Target) []byte {
return marshalStandardTarget(iptables.RuleReturn)
case iptables.RedirectTarget:
return marshalRedirectTarget()
+ case JumpTarget:
+ return marshalJumpTarget(tg)
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
}
func marshalStandardTarget(verdict iptables.RuleVerdict) []byte {
- nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget)
+ nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
target := linux.XTStandardTarget{
@@ -290,8 +292,25 @@ func marshalRedirectTarget() []byte {
},
}
copy(target.Target.Name[:], redirectTargetName)
-
+
ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
@@ -358,7 +377,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Convert input into a list of rules and their offsets.
var offset uint32
- var offsets []uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
nflog("set entries: processing entry at offset %d", offset)
@@ -419,11 +439,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
Target: target,
Matchers: matchers,
})
- offsets = append(offsets, offset)
+ offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
if initialOptValLen-len(optVal) != int(entry.NextOffset) {
nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
}
}
@@ -432,13 +453,13 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
- for ruleIdx, offset := range offsets {
+ for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
}
if offset == replace.Underflow[hook] {
if !validUnderflow(table.Rules[ruleIdx]) {
- nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP.")
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP")
return syserr.ErrInvalidArgument
}
table.Underflows[hk] = ruleIdx
@@ -467,16 +488,35 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// - There's some other rule after it.
// - There are no matchers.
if ruleIdx == len(table.Rules)-1 {
- nflog("user chain must have a rule or default policy.")
+ nflog("user chain must have a rule or default policy")
return syserr.ErrInvalidArgument
}
if len(table.Rules[ruleIdx].Matchers) != 0 {
- nflog("user chain's first node must have no matcheres.")
+ nflog("user chain's first node must have no matchers")
return syserr.ErrInvalidArgument
}
table.UserChains[target.Name] = ruleIdx + 1
}
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
+ if !ok {
+ continue
+ }
+
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
+ }
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
+
// TODO(gvisor.dev/issue/170): Support other chains.
// Since we only support modifying the INPUT chain and redirect for
// PREROUTING chain right now, make sure all other chains point to
@@ -572,7 +612,12 @@ func parseTarget(filter iptables.IPHeaderFilter, optVal []byte) (iptables.Target
buf = optVal[:linux.SizeOfXTStandardTarget]
binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
- return translateToStandardTarget(standardTarget.Verdict)
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
+ }
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
case errorTargetName:
// Error target.
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..c421b87cf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/iptables"
+)
+
+// JumpTarget implements iptables.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements iptables.Target.Action.
+func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) {
+ return iptables.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9757fbfba..e187276c5 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1318,6 +1318,22 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
return ib, nil
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1803,6 +1819,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -2086,7 +2110,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
- linux.IPV6_RECVTCLASS,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
linux.IPV6_TCLASS,
@@ -2424,6 +2447,8 @@ func (s *SocketOperations) controlMessages() socket.ControlMessages {
Timestamp: s.readCM.Timestamp,
HasTOS: s.readCM.HasTOS,
TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
HasIPPacketInfo: s.readCM.HasIPPacketInfo,
PacketInfo: s.readCM.PacketInfo,
},
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5afff2564..5f181f017 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -75,6 +75,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
switch protocol {
case syscall.IPPROTO_ICMP:
return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
case syscall.IPPROTO_UDP:
return header.UDPProtocolNumber, true, nil
case syscall.IPPROTO_TCP:
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 2f39a6f2b..88d5db9fc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"capability.go",
"clone.go",
+ "epoll.go",
"futex.go",
"linux64_amd64.go",
"linux64_arm64.go",
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
new file mode 100644
index 000000000..a6e48b836
--- /dev/null
+++ b/pkg/sentry/strace/epoll.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x ", eventAddr)
+ writeEpollEvent(&sb, e)
+ return sb.String()
+}
+
+func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x {", eventsAddr)
+ addr := eventsAddr
+ for i := uint64(0); i < numEvents; i++ {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(addr, &e); err != nil {
+ fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
+ continue
+ }
+ writeEpollEvent(&sb, e)
+ if uint64(sb.Len()) >= maxBytes {
+ sb.WriteString("...")
+ break
+ }
+ if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok {
+ fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr)
+ continue
+ }
+ }
+ sb.WriteString("}")
+ return sb.String()
+}
+
+func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) {
+ events := epollEventEvents.Parse(uint64(e.Events))
+ fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1])
+}
+
+var epollCtlOps = abi.ValueSet{
+ linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD",
+ linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL",
+ linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD",
+}
+
+var epollEventEvents = abi.FlagSet{
+ {Flag: linux.EPOLLIN, Name: "EPOLLIN"},
+ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
+ {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
+ {Flag: linux.EPOLLERR, Name: "EPOLLERR"},
+ {Flag: linux.EPOLLHUP, Name: "EPULLHUP"},
+ {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
+ {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
+ {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
+ {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"},
+ {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"},
+ {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"},
+ {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"},
+ {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"},
+ {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"},
+ {Flag: linux.EPOLLET, Name: "EPOLLET"},
+}
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index a4de545e9..71b92eaee 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -256,8 +256,8 @@ var linuxAMD64 = SyscallMap{
229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
231: makeSyscallInfo("exit_group", Hex),
- 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex),
- 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
235: makeSyscallInfo("utimes", Path, Timeval),
// 236: vserver (not implemented in the Linux kernel)
@@ -305,7 +305,7 @@ var linuxAMD64 = SyscallMap{
278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
- 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
283: makeSyscallInfo("timerfd_create", Hex, Hex),
284: makeSyscallInfo("eventfd", Hex),
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index 8bc38545f..bd7361a52 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -45,8 +45,8 @@ var linuxARM64 = SyscallMap{
18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
19: makeSyscallInfo("eventfd2", Hex, Hex),
20: makeSyscallInfo("epoll_create1", Hex),
- 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
- 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
23: makeSyscallInfo("dup", FD),
24: makeSyscallInfo("dup3", FD, FD, Hex),
25: makeSyscallInfo("fcntl", FD, Hex, Hex),
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 46cb2a1cc..77655558e 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -481,6 +481,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case EpollCtlOp:
+ output = append(output, epollCtlOps.Parse(uint64(args[arg].Int())))
+ case EpollEvent:
+ output = append(output, epollEvent(t, args[arg].Pointer()))
+ case EpollEvents:
+ output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
case SelectFDSet:
output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
@@ -549,6 +555,8 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint
output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
case PollFDs:
output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ case EpollEvents:
+ output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize))
case GetSockOptVal:
output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval)
case SetSockOptVal:
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 446d1e0f6..7e69b9279 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -228,6 +228,16 @@ const (
// SockOptLevel is the optname argument in getsockopt(2) and
// setsockopt(2).
SockOptName
+
+ // EpollCtlOp is the op argument to epoll_ctl(2).
+ EpollCtlOp
+
+ // EpollEvent is the event argument in epoll_ctl(2).
+ EpollEvent
+
+ // EpollEvents is an array of struct epoll_event. It is the events
+ // argument in epoll_wait(2)/epoll_pwait(2).
+ EpollEvents
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index fbef5b376..3ab93fbde 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// EpollCreate1 implements the epoll_create1(2) linux syscall.
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
@@ -164,3 +166,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// LINT.ThenChange(vfs2/epoll.go)
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 421845ebb..c21f14dc0 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string
return path, dirPath, nil
}
+// LINT.IfChange
+
func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -575,6 +577,10 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Ioctl implements linux syscall ioctl(2).
func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -650,6 +656,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/ioctl.go)
+
+// LINT.IfChange
+
// Getcwd implements the linux syscall getcwd(2).
func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -760,6 +770,10 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/fscontext.go)
+
+// LINT.IfChange
+
// Close implements linux syscall close(2).
func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -1094,6 +1108,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/fd.go)
+
const (
_FADV_NORMAL = 0
_FADV_RANDOM = 1
@@ -1141,6 +1157,8 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.IfChange
+
func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1421,6 +1439,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1480,6 +1502,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return n, nil, err
}
+// LINT.ThenChange(vfs2/stat.go)
+
+// LINT.IfChange
+
func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1516,6 +1542,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, unlinkAt(t, dirFD, addr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Truncate implements linux syscall truncate(2).
func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -1614,6 +1644,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/setstat.go)
+
// Umask implements linux syscall umask(2).
func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
mask := args[0].ModeT()
@@ -1621,6 +1653,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(mask), nil, nil
}
+// LINT.IfChange
+
// Change ownership of a file.
//
// uid and gid may be -1, in which case they will not be changed.
@@ -1987,6 +2021,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
}
+// LINT.ThenChange(vfs2/setstat.go)
+
+// LINT.IfChange
+
func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
@@ -2042,6 +2080,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
// Fallocate implements linux system call fallocate(2).
func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index f66f4ffde..b126fecc0 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -27,6 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Getdents implements linux syscall getdents(2) for 64bit systems.
func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
func (ds *direntSerializer) Written() int {
return ds.written
}
+
+// LINT.ThenChange(vfs2/getdents.go)
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 297e920c4..3f7691eae 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Lseek implements linux syscall lseek(2).
func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return uintptr(offset), nil, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 9959f6e61..91694d374 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return uintptr(addr), nil, nil
}
+// LINT.IfChange
+
// Mmap implements linux syscall mmap(2).
func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
prot := args[2].Int()
@@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(rv), nil, err
}
+// LINT.ThenChange(vfs2/mmap.go)
+
// Munmap implements linux syscall munmap(2).
func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 227692f06..78a2cb750 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskRead contains events that can be triggered on reads.
EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
@@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 8b66a9006..701b27b4a 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -23,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
return linux.Stat{
Dev: sattr.DeviceID,
@@ -131,8 +133,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
return err
}
s := statFromAttrs(t, d.Inode.StableAttr, uattr)
- _, err = s.CopyOut(t, statAddr)
- return err
+ return s.CopyOut(t, statAddr)
}
// fstat implements fstat for the given *fs.File.
@@ -142,8 +143,7 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
return err
}
s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
- _, err = s.CopyOut(t, statAddr)
- return err
+ return s.CopyOut(t, statAddr)
}
// Statx implements linux syscall statx(2).
@@ -299,3 +299,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
_, err = t.CopyOut(addr, &statfs)
return err
}
+
+// LINT.ThenChange(vfs2/stat.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 3e55235bd..5ad465ae3 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Sync implements linux system call sync(2).
func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
t.MountNamespace().SyncAll(t)
@@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
+
+// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index aba892939..506ee54ce 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskWrite contains events that can be triggered on writes.
//
@@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 9d8140b8a..2de5e3422 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// GetXattr implements linux syscall getxattr(2).
func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return getXattrFromPath(t, args, true)
@@ -418,3 +420,5 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
return d.Inode.RemoveXattr(t, d, name)
}
+
+// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 6b8a00b6e..f51761e81 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -5,18 +5,44 @@ package(licenses = ["notice"])
go_library(
name = "vfs2",
srcs = [
+ "epoll.go",
+ "epoll_unsafe.go",
+ "execve.go",
+ "fd.go",
+ "filesystem.go",
+ "fscontext.go",
+ "getdents.go",
+ "ioctl.go",
"linux64.go",
"linux64_override_amd64.go",
"linux64_override_arm64.go",
- "sys_read.go",
+ "mmap.go",
+ "path.go",
+ "poll.go",
+ "read_write.go",
+ "setstat.go",
+ "stat.go",
+ "sync.go",
+ "xattr.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
"//pkg/sentry/syscalls",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
new file mode 100644
index 000000000..d6cb0e79a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -0,0 +1,225 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EpollCreate1 implements Linux syscall epoll_create1(2).
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements Linux syscall epoll_create(2).
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater
+ // than zero" - epoll_create(2)
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements Linux syscall epoll_ctl(2).
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+ if epfile == file {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var event linux.EpollEvent
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.AddInterest(file, fd, event)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, ep.DeleteInterest(file, fd)
+ case linux.EPOLL_CTL_MOD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.ModifyInterest(file, fd, event)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
+ return 0, nil, syserror.EINVAL
+ }
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
+ // maxEvents), so that the buffer can be allocated on the stack.
+ var (
+ events [16]linux.EpollEvent
+ total int
+ ch chan struct{}
+ haveDeadline bool
+ deadline ktime.Time
+ )
+ for {
+ batchEvents := len(events)
+ if batchEvents > maxEvents {
+ batchEvents = maxEvents
+ }
+ n := ep.ReadEvents(events[:batchEvents])
+ maxEvents -= n
+ if n != 0 {
+ // Copy what we read out.
+ copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n])
+ eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
+ total += copiedEvents
+ if err != nil {
+ if total != 0 {
+ return uintptr(total), nil, nil
+ }
+ return 0, nil, err
+ }
+ // If we've filled the application's event buffer, we're done.
+ if maxEvents == 0 {
+ return uintptr(total), nil, nil
+ }
+ // Loop if we read a full batch, under the expectation that there
+ // may be more events to read.
+ if n == batchEvents {
+ continue
+ }
+ }
+ // We get here if n != batchEvents. If we read any number of events
+ // (just now, or in a previous iteration of this loop), or if timeout
+ // is 0 (such that epoll_wait should be non-blocking), return the
+ // events we've read so far to the application.
+ if total != 0 || timeout == 0 {
+ return uintptr(total), nil, nil
+ }
+ // In the first iteration of this loop, register with the epoll
+ // instance for readability events, but then immediately continue the
+ // loop since we need to retry ReadEvents() before blocking. In all
+ // subsequent iterations, block until events are available, the timeout
+ // expires, or an interrupt arrives.
+ if ch == nil {
+ var w waiter.Entry
+ w, ch = waiter.NewChannelEntry(nil)
+ epfile.EventRegister(&w, waiter.EventIn)
+ defer epfile.EventUnregister(&w)
+ } else {
+ // Set up the timer if a timeout was specified.
+ if timeout > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ // total must be 0 since otherwise we would have returned
+ // above.
+ return 0, nil, err
+ }
+ }
+ }
+}
+
+// EpollPwait implements Linux syscall epoll_pwait(2).
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return EpollWait(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
new file mode 100644
index 000000000..825f325bf
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "reflect"
+ "runtime"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{}))
+
+func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) {
+ if len(events) == 0 {
+ return 0, nil
+ }
+ // Cast events to a byte slice for copying.
+ var eventBytes []byte
+ eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes))
+ eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0])))
+ eventBytesHdr.Len = len(events) * sizeofEpollEvent
+ eventBytesHdr.Cap = len(events) * sizeofEpollEvent
+ copiedBytes, err := t.CopyOutBytes(addr, eventBytes)
+ runtime.KeepAlive(events)
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
+ return copiedEvents, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
new file mode 100644
index 000000000..aef0078a8
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathnameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+ return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ var executable fsbridge.File
+ closeOnExec := false
+ if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
+ // We must open the executable ourselves since dirfd is used as the
+ // starting point while resolving path, but the task working directory
+ // is used as the starting point while resolving interpreters (Linux:
+ // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() =>
+ // do_open_execat(fd=AT_FDCWD)), and the loader package is currently
+ // incapable of handling this correctly.
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ start := dirfile.VirtualDentry()
+ start.IncRef()
+ dirfile.DecRef()
+ closeOnExec = dirfileFlags.CloseOnExec
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ })
+ start.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ executable = fsbridge.NewVFSFile(file)
+ }
+
+ // Load the new TaskContext.
+ mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
+ defer mntns.DecRef()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewVFSLookup(mntns, root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
new file mode 100644
index 000000000..3afcea665
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Close implements Linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ _, file := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.OnClose(t)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements Linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements Linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ if oldfd == newfd {
+ // As long as oldfd is valid, dup2() does nothing and returns newfd.
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ file.DecRef()
+ return uintptr(newfd), nil, nil
+ }
+
+ return dup3(t, oldfd, newfd, 0)
+}
+
+// Dup3 implements Linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return dup3(t, oldfd, newfd, flags)
+}
+
+func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^linux.O_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().GetVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ minfd := args[2].Int()
+ fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, nil
+ case linux.F_GETFL:
+ return uintptr(file.StatusFlags()), nil, nil
+ case linux.F_SETFL:
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ default:
+ // TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
new file mode 100644
index 000000000..fc5ceea4c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -0,0 +1,326 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Link implements Linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Linkat implements Linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+ if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return syserror.ENOENT
+ }
+
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0))
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
+}
+
+// Mkdir implements Linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements Linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, mkdirat(t, dirfd, addr, mode)
+}
+
+func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+}
+
+// Mknod implements Linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ dev := args[2].Uint()
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev)
+}
+
+// Mknodat implements Linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ dev := args[3].Uint()
+ return 0, nil, mknodat(t, dirfd, addr, mode, dev)
+}
+
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ major, minor := linux.DecodeDeviceID(dev)
+ return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
+ Mode: linux.FileMode(mode &^ t.FSContext().Umask()),
+ DevMajor: uint32(major),
+ DevMinor: minor,
+ })
+}
+
+// Open implements Linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+ mode := args[2].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, flags, mode)
+}
+
+// Openat implements Linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ mode := args[3].ModeT()
+ return openat(t, dirfd, addr, flags, mode)
+}
+
+// Creat implements Linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
+}
+
+func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ return uintptr(fd), nil, err
+}
+
+// Rename implements Linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Renameat implements Linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */)
+}
+
+// Renameat2 implements Linux syscall renameat2(2).
+func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Uint()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ // "If oldpath refers to a symbolic link, the link is renamed" - rename(2)
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
+ Flags: flags,
+ })
+}
+
+// Rmdir implements Linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlink implements Linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlinkat implements Linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if flags&^linux.AT_REMOVEDIR != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirat(t, dirfd, pathAddr)
+ }
+ return 0, nil, unlinkat(t, dirfd, pathAddr)
+}
+
+// Symlink implements Linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ linkpathAddr := args[1].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr)
+}
+
+// Symlinkat implements Linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ newdirfd := args[1].Int()
+ linkpathAddr := args[2].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
+}
+
+func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+ target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ linkpath, err := copyInPath(t, linkpathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
new file mode 100644
index 000000000..317409a18
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getcwd implements Linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+
+ root := t.FSContext().RootDirectoryVFS2()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
+ root.DecRef()
+ wd.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Construct a byte slice containing a NUL terminator.
+ buf := t.CopyScratchBuffer(len(s) + 1)
+ copy(buf, s)
+ buf[len(buf)-1] = 0
+
+ // Write the pathname slice.
+ n, err := t.CopyOutBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Chdir implements Linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Fchdir implements Linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Chroot implements Linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetRootDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
new file mode 100644
index 000000000..ddc140b65
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Getdents implements Linux syscall getdents(2).
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, false /* isGetdents64 */)
+}
+
+// Getdents64 implements Linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, true /* isGetdents64 */)
+}
+
+func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ cb := getGetdentsCallback(t, addr, size, isGetdents64)
+ err := file.IterDirents(t, cb)
+ n := size - cb.remaining
+ putGetdentsCallback(cb)
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+type getdentsCallback struct {
+ t *kernel.Task
+ addr usermem.Addr
+ remaining int
+ isGetdents64 bool
+}
+
+var getdentsCallbackPool = sync.Pool{
+ New: func() interface{} {
+ return &getdentsCallback{}
+ },
+}
+
+func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+ cb := getdentsCallbackPool.Get().(*getdentsCallback)
+ *cb = getdentsCallback{
+ t: t,
+ addr: addr,
+ remaining: size,
+ isGetdents64: isGetdents64,
+ }
+ return cb
+}
+
+func putGetdentsCallback(cb *getdentsCallback) {
+ cb.t = nil
+ getdentsCallbackPool.Put(cb)
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
+ var buf []byte
+ if cb.isGetdents64 {
+ // struct linux_dirent64 {
+ // ino64_t d_ino; /* 64-bit inode number */
+ // off64_t d_off; /* 64-bit offset to next structure */
+ // unsigned short d_reclen; /* Size of this dirent */
+ // unsigned char d_type; /* File type */
+ // char d_name[]; /* Filename (null-terminated) */
+ // };
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ buf[18] = dirent.Type
+ copy(buf[19:], dirent.Name)
+ buf[size-1] = 0 // NUL terminator
+ } else {
+ // struct linux_dirent {
+ // unsigned long d_ino; /* Inode number */
+ // unsigned long d_off; /* Offset to next linux_dirent */
+ // unsigned short d_reclen; /* Length of this linux_dirent */
+ // char d_name[]; /* Filename (null-terminated) */
+ // /* length is actually (d_reclen - 2 -
+ // offsetof(struct linux_dirent, d_name)) */
+ // /*
+ // char pad; // Zero padding byte
+ // char d_type; // File type (only since Linux
+ // // 2.6.4); offset is (d_reclen - 1)
+ // */
+ // };
+ if cb.t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
+ }
+ size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ copy(buf[18:], dirent.Name)
+ buf[size-3] = 0 // NUL terminator
+ buf[size-2] = 0 // zero padding byte
+ buf[size-1] = dirent.Type
+ }
+ n, err := cb.t.CopyOutBytes(cb.addr, buf)
+ if err != nil {
+ // Don't report partially-written dirents by advancing cb.addr or
+ // cb.remaining.
+ return err
+ }
+ cb.addr += usermem.Addr(n)
+ cb.remaining -= n
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
new file mode 100644
index 000000000..5a2418da9
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Ioctl implements Linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ ret, err := file.Ioctl(t, t.MemoryManager(), args)
+ return ret, nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index e0ac32b33..7d220bc20 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package vfs2
import (
@@ -22,110 +24,142 @@ import (
// Override syscall table to add syscalls implementations from this package.
func Override(table map[uintptr]kernel.Syscall) {
table[0] = syscalls.Supported("read", Read)
-
- // Remove syscalls that haven't been converted yet. It's better to get ENOSYS
- // rather than a SIGSEGV deep in the stack.
- delete(table, 1) // write
- delete(table, 2) // open
- delete(table, 3) // close
- delete(table, 4) // stat
- delete(table, 5) // fstat
- delete(table, 6) // lstat
- delete(table, 7) // poll
- delete(table, 8) // lseek
- delete(table, 9) // mmap
- delete(table, 16) // ioctl
- delete(table, 17) // pread64
- delete(table, 18) // pwrite64
- delete(table, 19) // readv
- delete(table, 20) // writev
- delete(table, 21) // access
- delete(table, 22) // pipe
- delete(table, 32) // dup
- delete(table, 33) // dup2
- delete(table, 40) // sendfile
- delete(table, 59) // execve
- delete(table, 72) // fcntl
- delete(table, 73) // flock
- delete(table, 74) // fsync
- delete(table, 75) // fdatasync
- delete(table, 76) // truncate
- delete(table, 77) // ftruncate
- delete(table, 78) // getdents
- delete(table, 79) // getcwd
- delete(table, 80) // chdir
- delete(table, 81) // fchdir
- delete(table, 82) // rename
- delete(table, 83) // mkdir
- delete(table, 84) // rmdir
- delete(table, 85) // creat
- delete(table, 86) // link
- delete(table, 87) // unlink
- delete(table, 88) // symlink
- delete(table, 89) // readlink
- delete(table, 90) // chmod
- delete(table, 91) // fchmod
- delete(table, 92) // chown
- delete(table, 93) // fchown
- delete(table, 94) // lchown
- delete(table, 133) // mknod
- delete(table, 137) // statfs
- delete(table, 138) // fstatfs
- delete(table, 161) // chroot
- delete(table, 162) // sync
+ table[1] = syscalls.Supported("write", Write)
+ table[2] = syscalls.Supported("open", Open)
+ table[3] = syscalls.Supported("close", Close)
+ table[4] = syscalls.Supported("stat", Stat)
+ table[5] = syscalls.Supported("fstat", Fstat)
+ table[6] = syscalls.Supported("lstat", Lstat)
+ table[7] = syscalls.Supported("poll", Poll)
+ table[8] = syscalls.Supported("lseek", Lseek)
+ table[9] = syscalls.Supported("mmap", Mmap)
+ table[16] = syscalls.Supported("ioctl", Ioctl)
+ table[17] = syscalls.Supported("pread64", Pread64)
+ table[18] = syscalls.Supported("pwrite64", Pwrite64)
+ table[19] = syscalls.Supported("readv", Readv)
+ table[20] = syscalls.Supported("writev", Writev)
+ table[21] = syscalls.Supported("access", Access)
+ delete(table, 22) // pipe
+ table[23] = syscalls.Supported("select", Select)
+ table[32] = syscalls.Supported("dup", Dup)
+ table[33] = syscalls.Supported("dup2", Dup2)
+ delete(table, 40) // sendfile
+ delete(table, 41) // socket
+ delete(table, 42) // connect
+ delete(table, 43) // accept
+ delete(table, 44) // sendto
+ delete(table, 45) // recvfrom
+ delete(table, 46) // sendmsg
+ delete(table, 47) // recvmsg
+ delete(table, 48) // shutdown
+ delete(table, 49) // bind
+ delete(table, 50) // listen
+ delete(table, 51) // getsockname
+ delete(table, 52) // getpeername
+ delete(table, 53) // socketpair
+ delete(table, 54) // setsockopt
+ delete(table, 55) // getsockopt
+ table[59] = syscalls.Supported("execve", Execve)
+ table[72] = syscalls.Supported("fcntl", Fcntl)
+ delete(table, 73) // flock
+ table[74] = syscalls.Supported("fsync", Fsync)
+ table[75] = syscalls.Supported("fdatasync", Fdatasync)
+ table[76] = syscalls.Supported("truncate", Truncate)
+ table[77] = syscalls.Supported("ftruncate", Ftruncate)
+ table[78] = syscalls.Supported("getdents", Getdents)
+ table[79] = syscalls.Supported("getcwd", Getcwd)
+ table[80] = syscalls.Supported("chdir", Chdir)
+ table[81] = syscalls.Supported("fchdir", Fchdir)
+ table[82] = syscalls.Supported("rename", Rename)
+ table[83] = syscalls.Supported("mkdir", Mkdir)
+ table[84] = syscalls.Supported("rmdir", Rmdir)
+ table[85] = syscalls.Supported("creat", Creat)
+ table[86] = syscalls.Supported("link", Link)
+ table[87] = syscalls.Supported("unlink", Unlink)
+ table[88] = syscalls.Supported("symlink", Symlink)
+ table[89] = syscalls.Supported("readlink", Readlink)
+ table[90] = syscalls.Supported("chmod", Chmod)
+ table[91] = syscalls.Supported("fchmod", Fchmod)
+ table[92] = syscalls.Supported("chown", Chown)
+ table[93] = syscalls.Supported("fchown", Fchown)
+ table[94] = syscalls.Supported("lchown", Lchown)
+ table[132] = syscalls.Supported("utime", Utime)
+ table[133] = syscalls.Supported("mknod", Mknod)
+ table[137] = syscalls.Supported("statfs", Statfs)
+ table[138] = syscalls.Supported("fstatfs", Fstatfs)
+ table[161] = syscalls.Supported("chroot", Chroot)
+ table[162] = syscalls.Supported("sync", Sync)
delete(table, 165) // mount
delete(table, 166) // umount2
- delete(table, 172) // iopl
- delete(table, 173) // ioperm
delete(table, 187) // readahead
- delete(table, 188) // setxattr
- delete(table, 189) // lsetxattr
- delete(table, 190) // fsetxattr
- delete(table, 191) // getxattr
- delete(table, 192) // lgetxattr
- delete(table, 193) // fgetxattr
+ table[188] = syscalls.Supported("setxattr", Setxattr)
+ table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
+ table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
+ table[191] = syscalls.Supported("getxattr", Getxattr)
+ table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
+ table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
+ table[194] = syscalls.Supported("listxattr", Listxattr)
+ table[195] = syscalls.Supported("llistxattr", Llistxattr)
+ table[196] = syscalls.Supported("flistxattr", Flistxattr)
+ table[197] = syscalls.Supported("removexattr", Removexattr)
+ table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
+ table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
delete(table, 206) // io_setup
delete(table, 207) // io_destroy
delete(table, 208) // io_getevents
delete(table, 209) // io_submit
delete(table, 210) // io_cancel
- delete(table, 213) // epoll_create
- delete(table, 214) // epoll_ctl_old
- delete(table, 215) // epoll_wait_old
- delete(table, 216) // remap_file_pages
- delete(table, 217) // getdents64
- delete(table, 232) // epoll_wait
- delete(table, 233) // epoll_ctl
+ table[213] = syscalls.Supported("epoll_create", EpollCreate)
+ table[217] = syscalls.Supported("getdents64", Getdents64)
+ delete(table, 221) // fdavise64
+ table[232] = syscalls.Supported("epoll_wait", EpollWait)
+ table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
+ table[235] = syscalls.Supported("utimes", Utimes)
delete(table, 253) // inotify_init
delete(table, 254) // inotify_add_watch
delete(table, 255) // inotify_rm_watch
- delete(table, 257) // openat
- delete(table, 258) // mkdirat
- delete(table, 259) // mknodat
- delete(table, 260) // fchownat
- delete(table, 261) // futimesat
- delete(table, 262) // fstatat
- delete(table, 263) // unlinkat
- delete(table, 264) // renameat
- delete(table, 265) // linkat
- delete(table, 266) // symlinkat
- delete(table, 267) // readlinkat
- delete(table, 268) // fchmodat
- delete(table, 269) // faccessat
- delete(table, 270) // pselect
- delete(table, 271) // ppoll
+ table[257] = syscalls.Supported("openat", Openat)
+ table[258] = syscalls.Supported("mkdirat", Mkdirat)
+ table[259] = syscalls.Supported("mknodat", Mknodat)
+ table[260] = syscalls.Supported("fchownat", Fchownat)
+ table[261] = syscalls.Supported("futimens", Futimens)
+ table[262] = syscalls.Supported("newfstatat", Newfstatat)
+ table[263] = syscalls.Supported("unlinkat", Unlinkat)
+ table[264] = syscalls.Supported("renameat", Renameat)
+ table[265] = syscalls.Supported("linkat", Linkat)
+ table[266] = syscalls.Supported("symlinkat", Symlinkat)
+ table[267] = syscalls.Supported("readlinkat", Readlinkat)
+ table[268] = syscalls.Supported("fchmodat", Fchmodat)
+ table[269] = syscalls.Supported("faccessat", Faccessat)
+ table[270] = syscalls.Supported("pselect", Pselect)
+ table[271] = syscalls.Supported("ppoll", Ppoll)
+ delete(table, 275) // splice
+ delete(table, 276) // tee
+ table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
+ table[280] = syscalls.Supported("utimensat", Utimensat)
+ table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
+ delete(table, 282) // signalfd
+ delete(table, 283) // timerfd_create
+ delete(table, 284) // eventfd
delete(table, 285) // fallocate
- delete(table, 291) // epoll_create1
- delete(table, 292) // dup3
+ delete(table, 286) // timerfd_settime
+ delete(table, 287) // timerfd_gettime
+ delete(table, 288) // accept4
+ delete(table, 289) // signalfd4
+ delete(table, 290) // eventfd2
+ table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
+ table[292] = syscalls.Supported("dup3", Dup3)
delete(table, 293) // pipe2
delete(table, 294) // inotify_init1
- delete(table, 295) // preadv
- delete(table, 296) // pwritev
- delete(table, 306) // syncfs
- delete(table, 316) // renameat2
+ table[295] = syscalls.Supported("preadv", Preadv)
+ table[296] = syscalls.Supported("pwritev", Pwritev)
+ delete(table, 299) // recvmmsg
+ table[306] = syscalls.Supported("syncfs", Syncfs)
+ delete(table, 307) // sendmmsg
+ table[316] = syscalls.Supported("renameat2", Renameat2)
delete(table, 319) // memfd_create
- delete(table, 322) // execveat
- delete(table, 327) // preadv2
- delete(table, 328) // pwritev2
- delete(table, 332) // statx
+ table[322] = syscalls.Supported("execveat", Execveat)
+ table[327] = syscalls.Supported("preadv2", Preadv2)
+ table[328] = syscalls.Supported("pwritev2", Pwritev2)
+ table[332] = syscalls.Supported("statx", Statx)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
index 6af5c400f..a6b367468 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package vfs2
import (
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
new file mode 100644
index 000000000..60a43f0a0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mmap implements Linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // mmap unconditionally requires that the FD is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !file.IsWritable() {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
new file mode 100644
index 000000000..97da6c647
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+ pathname, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fspath.Path{}, err
+ }
+ return fspath.Parse(pathname), nil
+}
+
+type taskPathOperation struct {
+ pop vfs.PathOperation
+ haveStartRef bool
+}
+
+func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) {
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ haveStartRef := false
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ root.DecRef()
+ return taskPathOperation{}, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ haveStartRef = true
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ root.DecRef()
+ return taskPathOperation{}, syserror.EBADF
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ haveStartRef = true
+ dirfile.DecRef()
+ }
+ }
+ return taskPathOperation{
+ pop: vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ },
+ haveStartRef: haveStartRef,
+ }, nil
+}
+
+func (tpop *taskPathOperation) Release() {
+ tpop.pop.Root.DecRef()
+ if tpop.haveStartRef {
+ tpop.pop.Start.DecRef()
+ tpop.haveStartRef = false
+ }
+}
+
+type shouldAllowEmptyPath bool
+
+const (
+ disallowEmptyPath shouldAllowEmptyPath = false
+ allowEmptyPath shouldAllowEmptyPath = true
+)
+
+type shouldFollowFinalSymlink bool
+
+const (
+ nofollowFinalSymlink shouldFollowFinalSymlink = false
+ followFinalSymlink shouldFollowFinalSymlink = true
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
new file mode 100644
index 000000000..dbf4882da
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -0,0 +1,584 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable files for poll & select. This has no
+// equivalent in Linux; it exists in gVisor since allocation failure in Go is
+// unrecoverable.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file description and waiter of a PollFD.
+type pollState struct {
+ file *vfs.FileDescription
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFileVFS2(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
+// positive if interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ haveTimeout := timeout >= 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := copyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ return tsRemaining.CopyOut(t, timespecAddr)
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ return tvRemaining.CopyOut(t, timevalAddr)
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ var timeval linux.Timeval
+ if err := timeval.CopyIn(t, timevalAddr); err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ if t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
+ }
+ var maskStruct sigSetWithSize
+ if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ return 0, nil, err
+ }
+ if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ var timespec linux.Timespec
+ if err := timespec.CopyIn(t, timespecAddr); err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
+
+func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+ if maskAddr == 0 {
+ return nil
+ }
+ if maskSize != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if err := mask.CopyIn(t, maskAddr); err != nil {
+ return err
+ }
+ mask &^= kernel.UnblockableSignals
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
new file mode 100644
index 000000000..35f6308d6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -0,0 +1,511 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+ eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements Linux syscall read(2).
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readv implements Linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.Read(t, dst, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Read(t, dst, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pread64 implements Linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Preadv implements Linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements Linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.ReadOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = read(t, file, dst, opts)
+ } else {
+ n, err = pread(t, file, dst, offset, opts)
+ }
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.PRead(t, dst, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PRead(t, dst, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Write implements Linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Writev implements Linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.Write(t, src, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Write(t, src, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pwrite64 implements Linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Pwritev implements Linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements Linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.WriteOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = write(t, file, src, opts)
+ } else {
+ n, err = pwrite(t, file, src, offset, opts)
+ }
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.PWrite(t, src, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PWrite(t, src, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Lseek implements Linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newoff, err := file.Seek(t, offset, whence)
+ return uintptr(newoff), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
new file mode 100644
index 000000000..9250659ff
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -0,0 +1,380 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+// Chmod implements Linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode)
+}
+
+// Fchmodat implements Linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
+}
+
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Fchmod implements Linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].ModeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Chown implements Linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */)
+}
+
+// Lchown implements Linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ owner := args[2].Int()
+ group := args[3].Int()
+ flags := args[4].Int()
+ return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
+}
+
+func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error {
+ userns := t.UserNamespace()
+ if owner != -1 {
+ kuid := userns.MapToKUID(auth.UID(owner))
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_UID
+ opts.Stat.UID = uint32(kuid)
+ }
+ if group != -1 {
+ kgid := userns.MapToKGID(auth.GID(group))
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_GID
+ opts.Stat.GID = uint32(kgid)
+ }
+ return nil
+}
+
+// Fchown implements Linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ owner := args[1].Int()
+ group := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, file.SetStat(t, opts)
+}
+
+// Truncate implements Linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Ftruncate implements Linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Utime implements Linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times linux.Utime
+ if err := times.CopyIn(t, timesAddr); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime.Sec = times.Actime
+ opts.Stat.Mtime.Sec = times.Modtime
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimes implements Linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Futimens implements Linux syscall futimens(2).
+func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ timesAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.SetStat(t, opts)
+}
+
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_ATIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Nsec),
+ }
+ }
+ if times[1].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_MTIME
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Nsec),
+ }
+ }
+ return nil
+}
+
+func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.SetStat() instead of
+ // VirtualFilesystem.SetStatAt(), since the former may be able
+ // to use opened file state to expedite the SetStat.
+ err := dirfile.SetStat(t, *opts)
+ dirfile.DecRef()
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+ return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ }, opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
new file mode 100644
index 000000000..dca8d7011
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -0,0 +1,346 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stat implements Linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */)
+}
+
+// Lstat implements Linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2).
+func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+ return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
+}
+
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+}
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
+
+func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
+ return linux.Timespec{
+ Sec: sxts.Sec,
+ Nsec: int64(sxts.Nsec),
+ }
+}
+
+// Fstat implements Linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ statx, err := file.Stat(t, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return 0, nil, stat.CopyOut(t, statAddr)
+}
+
+// Statx implements Linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: mask,
+ Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE),
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for statx(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+}
+
+func userifyStatx(t *kernel.Task, statx *linux.Statx) {
+ userns := t.UserNamespace()
+ statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow())
+ statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow())
+}
+
+// Readlink implements Linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+ return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size)
+}
+
+// Access implements Linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Faccessat implements Linux syscall access(2).
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Readlinkat implements Linux syscall mknodat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ return readlinkat(t, dirfd, pathAddr, bufAddr, size)
+}
+
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+ if int(size) <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // "Since Linux 2.6.39, pathname can be an empty string, in which case the
+ // call operates on the symbolic link referred to by dirfd ..." -
+ // readlinkat(2)
+ tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if len(target) > int(size) {
+ target = target[:size]
+ }
+ n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target))
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Statfs implements Linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
+
+// Fstatfs implements Linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufAddr := args[1].Pointer()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
new file mode 100644
index 000000000..365250b0b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements Linux syscall sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t)
+}
+
+// Syncfs implements Linux syscall syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SyncFS(t)
+}
+
+// Fsync implements Linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.Sync(t)
+}
+
+// Fdatasync implements Linux syscall fdatasync(2).
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata.
+ return Fsync(t, args)
+}
+
+// SyncFileRange implements Linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ flags := args[3].Uint()
+
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if nbytes < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of
+ // [offset, offset+nbytes).
+ return 0, nil, file.Sync(t)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_read.go b/pkg/sentry/syscalls/linux/vfs2/sys_read.go
deleted file mode 100644
index 7667524c7..000000000
--- a/pkg/sentry/syscalls/linux/vfs2/sys_read.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs2
-
-import (
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // EventMaskRead contains events that can be triggered on reads.
- EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
-)
-
-// Read implements linux syscall read(2). Note that we try to get a buffer that
-// is exactly the size requested because some applications like qemu expect
-// they can do large reads all at once. Bug for bug. Same for other read
-// calls below.
-func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- fd := args[0].Int()
- addr := args[1].Pointer()
- size := args[2].SizeT()
-
- file := t.GetFileVFS2(fd)
- if file == nil {
- return 0, nil, syserror.EBADF
- }
- defer file.DecRef()
-
- // Check that the size is legitimate.
- si := int(size)
- if si < 0 {
- return 0, nil, syserror.EINVAL
- }
-
- // Get the destination of the read.
- dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- if err != nil {
- return 0, nil, err
- }
-
- n, err := read(t, file, dst, vfs.ReadOptions{})
- t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, linux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
-}
-
-func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- n, err := file.Read(t, dst, opts)
- if err != syserror.ErrWouldBlock {
- return n, err
- }
-
- // Register for notifications.
- w, ch := waiter.NewChannelEntry(nil)
- file.EventRegister(&w, EventMaskRead)
-
- total := n
- for {
- // Shorten dst to reflect bytes previously read.
- dst = dst.DropFirst(int(n))
-
- // Issue the request and break out if it completes with anything other than
- // "would block".
- n, err := file.Read(t, dst, opts)
- total += n
- if err != syserror.ErrWouldBlock {
- break
- }
- if err := t.Block(ch); err != nil {
- break
- }
- }
- file.EventUnregister(&w)
-
- return total, err
-}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
new file mode 100644
index 000000000..89e9ff4d7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -0,0 +1,353 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Listxattr implements Linux syscall listxattr(2).
+func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, followFinalSymlink)
+}
+
+// Llistxattr implements Linux syscall llistxattr(2).
+func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, nofollowFinalSymlink)
+}
+
+func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Flistxattr implements Linux syscall flistxattr(2).
+func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ names, err := file.Listxattr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Getxattr implements Linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, followFinalSymlink)
+}
+
+// Lgetxattr implements Linux syscall lgetxattr(2).
+func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, nofollowFinalSymlink)
+}
+
+func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Fgetxattr implements Linux syscall fgetxattr(2).
+func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := file.Getxattr(t, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Setxattr implements Linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, followFinalSymlink)
+}
+
+// Lsetxattr implements Linux syscall lsetxattr(2).
+func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, nofollowFinalSymlink)
+}
+
+func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Fsetxattr implements Linux syscall fsetxattr(2).
+func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Removexattr implements Linux syscall removexattr(2).
+func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, followFinalSymlink)
+}
+
+// Lremovexattr implements Linux syscall lremovexattr(2).
+func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, nofollowFinalSymlink)
+}
+
+func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+}
+
+// Fremovexattr implements Linux syscall fremovexattr(2).
+func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Removexattr(t, name)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+ if size > linux.XATTR_LIST_MAX {
+ size = linux.XATTR_LIST_MAX
+ }
+ var buf bytes.Buffer
+ for _, name := range names {
+ buf.WriteString(name)
+ buf.WriteByte(0)
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the list.
+ return buf.Len(), nil
+ }
+ if buf.Len() > int(size) {
+ if size >= linux.XATTR_LIST_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(listAddr, buf.Bytes())
+}
+
+func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ return "", syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return "", err
+ }
+ return gohacks.StringFromImmutableBytes(buf), nil
+}
+
+func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the value.
+ return len(value), nil
+ }
+ if len(value) > int(size) {
+ if size >= linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value))
+}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 0b4f18ab5..07c8383e6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -43,6 +43,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index eed41139b..3da45d744 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -202,6 +202,9 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
// Add epi to file.epolls so that it is removed when the last
// FileDescription reference is dropped.
file.epollMu.Lock()
+ if file.epolls == nil {
+ file.epolls = make(map[*epollInterest]struct{})
+ }
file.epolls[epi] = struct{}{}
file.epollMu.Unlock()
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index 1fe766a44..bc7581698 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -26,6 +26,7 @@ import (
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -160,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
loop:
for {
@@ -361,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
//go:linkname rand32 runtime.fastrand
func rand32() uint32
-
-// This is copy/pasted from runtime.noescape(), and is needed because arguments
-// apparently escape from all functions defined by linkname.
-//
-//go:nosplit
-func noescape(p unsafe.Pointer) unsafe.Pointer {
- x := uintptr(p)
- return unsafe.Pointer(x ^ 0)
-}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 8a0b382f6..eb4ebb511 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -228,7 +228,7 @@ func (rp *ResolvingPath) Advance() {
rp.pit = next
} else { // at end of path segment, continue with next one
rp.curPart--
- rp.pit = rp.parts[rp.curPart-1]
+ rp.pit = rp.parts[rp.curPart]
}
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 8f29031b2..73f8043be 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -385,15 +385,11 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
// Only a regular file can be executed.
stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
if err != nil {
+ fd.DecRef()
return nil, err
}
- if stat.Mask&linux.STATX_TYPE != 0 {
- // This shouldn't happen, but if type can't be retrieved, file can't
- // be executed.
- return nil, syserror.EACCES
- }
- if t := linux.FileMode(stat.Mode).FileType(); t != linux.ModeRegular {
- ctx.Infof("%q is not a regular file: %v", pop.Path, t)
+ if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
+ fd.DecRef()
return nil, syserror.EACCES
}
}
diff --git a/pkg/syncevent/BUILD b/pkg/syncevent/BUILD
new file mode 100644
index 000000000..0500a22cf
--- /dev/null
+++ b/pkg/syncevent/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+licenses(["notice"])
+
+go_library(
+ name = "syncevent",
+ srcs = [
+ "broadcaster.go",
+ "receiver.go",
+ "source.go",
+ "syncevent.go",
+ "waiter_amd64.s",
+ "waiter_arm64.s",
+ "waiter_asm_unsafe.go",
+ "waiter_noasm_unsafe.go",
+ "waiter_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/atomicbitops",
+ "//pkg/sync",
+ ],
+)
+
+go_test(
+ name = "syncevent_test",
+ size = "small",
+ srcs = [
+ "broadcaster_test.go",
+ "syncevent_example_test.go",
+ "waiter_test.go",
+ ],
+ library = ":syncevent",
+ deps = [
+ "//pkg/sleep",
+ "//pkg/sync",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go
new file mode 100644
index 000000000..4bff59e7d
--- /dev/null
+++ b/pkg/syncevent/broadcaster.go
@@ -0,0 +1,218 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// Broadcaster is an implementation of Source that supports any number of
+// subscribed Receivers.
+//
+// The zero value of Broadcaster is valid and has no subscribed Receivers.
+// Broadcaster is not copyable by value.
+//
+// All Broadcaster methods may be called concurrently from multiple goroutines.
+type Broadcaster struct {
+ // Broadcaster is implemented as a hash table where keys are assigned by
+ // the Broadcaster and returned as SubscriptionIDs, making it safe to use
+ // the identity function for hashing. The hash table resolves collisions
+ // using linear probing and features Robin Hood insertion and backward
+ // shift deletion in order to support a relatively high load factor
+ // efficiently, which matters since the cost of Broadcast is linear in the
+ // size of the table.
+
+ // mu protects the following fields.
+ mu sync.Mutex
+
+ // Invariants: len(table) is 0 or a power of 2.
+ table []broadcasterSlot
+
+ // load is the number of entries in table with receiver != nil.
+ load int
+
+ lastID SubscriptionID
+}
+
+type broadcasterSlot struct {
+ // Invariants: If receiver == nil, then filter == NoEvents and id == 0.
+ // Otherwise, id != 0.
+ receiver *Receiver
+ filter Set
+ id SubscriptionID
+}
+
+const (
+ broadcasterMinNonZeroTableSize = 2 // must be a power of 2 > 1
+
+ broadcasterMaxLoadNum = 13
+ broadcasterMaxLoadDen = 16
+)
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID {
+ b.mu.Lock()
+
+ // Assign an ID for this subscription.
+ b.lastID++
+ id := b.lastID
+
+ // Expand the table if over the maximum load factor:
+ //
+ // load / len(b.table) > broadcasterMaxLoadNum / broadcasterMaxLoadDen
+ // load * broadcasterMaxLoadDen > broadcasterMaxLoadNum * len(b.table)
+ b.load++
+ if (b.load * broadcasterMaxLoadDen) > (broadcasterMaxLoadNum * len(b.table)) {
+ // Double the number of slots in the new table.
+ newlen := broadcasterMinNonZeroTableSize
+ if len(b.table) != 0 {
+ newlen = 2 * len(b.table)
+ }
+ if newlen <= cap(b.table) {
+ // Reuse excess capacity in the current table, moving entries not
+ // already in their first-probed positions to better ones.
+ newtable := b.table[:newlen]
+ newmask := uint64(newlen - 1)
+ for i := range b.table {
+ if b.table[i].receiver != nil && uint64(b.table[i].id)&newmask != uint64(i) {
+ entry := b.table[i]
+ b.table[i] = broadcasterSlot{}
+ broadcasterTableInsert(newtable, entry.id, entry.receiver, entry.filter)
+ }
+ }
+ b.table = newtable
+ } else {
+ newtable := make([]broadcasterSlot, newlen)
+ // Copy existing entries to the new table.
+ for i := range b.table {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ }
+ }
+ // Switch to the new table.
+ b.table = newtable
+ }
+ }
+
+ broadcasterTableInsert(b.table, id, r, filter)
+ b.mu.Unlock()
+ return id
+}
+
+// Preconditions: table must not be full. len(table) is a power of 2.
+func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) {
+ entry := broadcasterSlot{
+ receiver: r,
+ filter: filter,
+ id: id,
+ }
+ mask := uint64(len(table) - 1)
+ i := uint64(id) & mask
+ disp := uint64(0)
+ for {
+ if table[i].receiver == nil {
+ table[i] = entry
+ return
+ }
+ // If we've been displaced farther from our first-probed slot than the
+ // element stored in this one, swap elements and switch to inserting
+ // the replaced one. (This is Robin Hood insertion.)
+ slotDisp := (i - uint64(table[i].id)) & mask
+ if disp > slotDisp {
+ table[i], entry = entry, table[i]
+ disp = slotDisp
+ }
+ i = (i + 1) & mask
+ disp++
+ }
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (b *Broadcaster) UnsubscribeEvents(id SubscriptionID) {
+ b.mu.Lock()
+
+ mask := uint64(len(b.table) - 1)
+ i := uint64(id) & mask
+ for {
+ if b.table[i].id == id {
+ // Found the element to remove. Move all subsequent elements
+ // backward until we either find an empty slot, or an element that
+ // is already in its first-probed slot. (This is backward shift
+ // deletion.)
+ for {
+ next := (i + 1) & mask
+ if b.table[next].receiver == nil {
+ break
+ }
+ if uint64(b.table[next].id)&mask == next {
+ break
+ }
+ b.table[i] = b.table[next]
+ i = next
+ }
+ b.table[i] = broadcasterSlot{}
+ break
+ }
+ i = (i + 1) & mask
+ }
+
+ // If a table 1/4 of the current size would still be at or under the
+ // maximum load factor (i.e. the current table size is at least two
+ // expansions bigger than necessary), halve the size of the table to reduce
+ // the cost of Broadcast. Since we are concerned with iteration time and
+ // not memory usage, reuse the existing slice to reduce future allocations
+ // from table re-expansion.
+ b.load--
+ if len(b.table) > broadcasterMinNonZeroTableSize && (b.load*(4*broadcasterMaxLoadDen)) <= (broadcasterMaxLoadNum*len(b.table)) {
+ newlen := len(b.table) / 2
+ newtable := b.table[:newlen]
+ for i := newlen; i < len(b.table); i++ {
+ if b.table[i].receiver != nil {
+ broadcasterTableInsert(newtable, b.table[i].id, b.table[i].receiver, b.table[i].filter)
+ b.table[i] = broadcasterSlot{}
+ }
+ }
+ b.table = newtable
+ }
+
+ b.mu.Unlock()
+}
+
+// Broadcast notifies all Receivers subscribed to the Broadcaster of the subset
+// of events to which they subscribed. The order in which Receivers are
+// notified is unspecified.
+func (b *Broadcaster) Broadcast(events Set) {
+ b.mu.Lock()
+ for i := range b.table {
+ if intersection := events & b.table[i].filter; intersection != 0 {
+ // We don't need to check if broadcasterSlot.receiver is nil, since
+ // if it is then broadcasterSlot.filter is 0.
+ b.table[i].receiver.Notify(intersection)
+ }
+ }
+ b.mu.Unlock()
+}
+
+// FilteredEvents returns the set of events for which Broadcast will notify at
+// least one Receiver, i.e. the union of filters for all subscribed Receivers.
+func (b *Broadcaster) FilteredEvents() Set {
+ var es Set
+ b.mu.Lock()
+ for i := range b.table {
+ es |= b.table[i].filter
+ }
+ b.mu.Unlock()
+ return es
+}
diff --git a/pkg/syncevent/broadcaster_test.go b/pkg/syncevent/broadcaster_test.go
new file mode 100644
index 000000000..e88779e23
--- /dev/null
+++ b/pkg/syncevent/broadcaster_test.go
@@ -0,0 +1,376 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func TestBroadcasterFilter(t *testing.T) {
+ const numReceivers = 2 * MaxEvents
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1<<(i%MaxEvents))
+ }
+ for ev := 0; ev < MaxEvents; ev++ {
+ br.Broadcast(1 << ev)
+ for i := range ws {
+ want := NoEvents
+ if i%MaxEvents == ev {
+ want = 1 << ev
+ }
+ if got := ws[i].Receiver().PendingAndAckAll(); got != want {
+ t.Errorf("after Broadcast of event %d: waiter %d has pending event set %#x, wanted %#x", ev, i, got, want)
+ }
+ }
+ }
+}
+
+// TestBroadcasterManySubscriptions tests that subscriptions are not lost by
+// table expansion/compaction.
+func TestBroadcasterManySubscriptions(t *testing.T) {
+ const numReceivers = 5000 // arbitrary
+
+ var br Broadcaster
+ ws := make([]Waiter, numReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ ids := make([]SubscriptionID, numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Subscribe receiver i.
+ ids[i] = br.SubscribeEvents(ws[i].Receiver(), 1)
+ // Check that receivers [0, i] are subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[j].Pending() != 1 {
+ t.Errorf("receiver %d did not receive an event after subscription of receiver %d", j, i)
+ }
+ ws[j].Ack(1)
+ }
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numReceivers)
+ for i := 0; i < numReceivers; i++ {
+ // Unsubscribe receiver unsub[i].
+ br.UnsubscribeEvents(ids[unsub[i]])
+ // Check that receivers [unsub[0], unsub[i]] are not subscribed, and that
+ // receivers (unsub[i], unsub[numReceivers]) are still subscribed.
+ br.Broadcast(1)
+ for j := 0; j <= i; j++ {
+ if ws[unsub[j]].Pending() != 0 {
+ t.Errorf("unsub iteration %d: receiver %d received an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ }
+ for j := i + 1; j < numReceivers; j++ {
+ if ws[unsub[j]].Pending() != 1 {
+ t.Errorf("unsub iteration %d: receiver %d did not receive an event after unsubscription of receiver %d", i, unsub[j], unsub[i])
+ }
+ ws[unsub[j]].Ack(1)
+ }
+ }
+}
+
+var (
+ receiverCountsNonZero = []int{1, 4, 16, 64}
+ receiverCountsIncludingZero = append([]int{0}, receiverCountsNonZero...)
+)
+
+// BenchmarkBroadcasterX, BenchmarkMapX, and BenchmarkQueueX benchmark usage
+// pattern X (described in terms of Broadcaster) with Broadcaster, a
+// Mutex-protected map[*Receiver]Set, and waiter.Queue respectively.
+
+// BenchmarkXxxSubscribeUnsubscribe measures the cost of a Subscribe/Unsubscribe
+// cycle.
+
+func BenchmarkBroadcasterSubscribeUnsubscribe(b *testing.B) {
+ var br Broadcaster
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ id := br.SubscribeEvents(w.Receiver(), 1)
+ br.UnsubscribeEvents(id)
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribe(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ m[w.Receiver()] = Set(1)
+ mu.Unlock()
+ mu.Lock()
+ delete(m, w.Receiver())
+ mu.Unlock()
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribe(b *testing.B) {
+ var q waiter.Queue
+ e, _ := waiter.NewChannelEntry(nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.EventRegister(&e, 1)
+ q.EventUnregister(&e)
+ }
+}
+
+// BenchmarkXxxSubscribeUnsubscribeBatch is similar to
+// BenchmarkXxxSubscribeUnsubscribe, but subscribes and unsubscribes a large
+// number of Receivers at a time in order to measure the amortized overhead of
+// table expansion/compaction. (Since waiter.Queue is implemented using a
+// linked list, BenchmarkQueueSubscribeUnsubscribe and
+// BenchmarkQueueSubscribeUnsubscribeBatch should produce nearly the same
+// result.)
+
+const numBatchReceivers = 1000
+
+func BenchmarkBroadcasterSubscribeUnsubscribeBatch(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+ ids := make([]SubscriptionID, numBatchReceivers)
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ ids[j] = br.SubscribeEvents(ws[j].Receiver(), 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ br.UnsubscribeEvents(ids[unsub[j]])
+ }
+ }
+}
+
+func BenchmarkMapSubscribeUnsubscribeBatch(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, numBatchReceivers)
+ for i := range ws {
+ ws[i].Init()
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ m[ws[j].Receiver()] = Set(1)
+ mu.Unlock()
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ mu.Lock()
+ delete(m, ws[unsub[j]].Receiver())
+ mu.Unlock()
+ }
+ }
+}
+
+func BenchmarkQueueSubscribeUnsubscribeBatch(b *testing.B) {
+ var q waiter.Queue
+ es := make([]waiter.Entry, numBatchReceivers)
+ for i := range es {
+ es[i], _ = waiter.NewChannelEntry(nil)
+ }
+
+ // Generate a random order for unsubscriptions.
+ unsub := rand.Perm(numBatchReceivers)
+
+ b.ResetTimer()
+ for i := 0; i < b.N/numBatchReceivers; i++ {
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventRegister(&es[j], 1)
+ }
+ for j := 0; j < numBatchReceivers; j++ {
+ q.EventUnregister(&es[unsub[j]])
+ }
+ }
+}
+
+// BenchmarkXxxBroadcastRedundant measures how long it takes to Broadcast
+// already-pending events to multiple Receivers.
+
+func BenchmarkBroadcasterBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+ br.Broadcast(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastRedundant(b *testing.B) {
+ for _, n := range receiverCountsIncludingZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ for i := 0; i < n; i++ {
+ e, _ := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ }
+ q.Notify(1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ }
+ })
+ }
+}
+
+// BenchmarkXxxBroadcastAck measures how long it takes to Broadcast events to
+// multiple Receivers, check that all Receivers have received the event, and
+// clear the event from all Receivers.
+
+func BenchmarkBroadcasterBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var br Broadcaster
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ br.SubscribeEvents(ws[i].Receiver(), 1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ br.Broadcast(1)
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkMapBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var mu sync.Mutex
+ m := make(map[*Receiver]Set)
+ ws := make([]Waiter, n)
+ for i := range ws {
+ ws[i].Init()
+ m[ws[i].Receiver()] = Set(1)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ mu.Lock()
+ for r := range m {
+ r.Notify(1)
+ }
+ mu.Unlock()
+ for j := range ws {
+ if got, want := ws[j].Pending(), Set(1); got != want {
+ b.Fatalf("Receiver.Pending(): got %#x, wanted %#x", got, want)
+ }
+ ws[j].Ack(1)
+ }
+ }
+ })
+ }
+}
+
+func BenchmarkQueueBroadcastAck(b *testing.B) {
+ for _, n := range receiverCountsNonZero {
+ b.Run(fmt.Sprintf("%d", n), func(b *testing.B) {
+ var q waiter.Queue
+ chs := make([]chan struct{}, n)
+ for i := range chs {
+ e, ch := waiter.NewChannelEntry(nil)
+ q.EventRegister(&e, 1)
+ chs[i] = ch
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ q.Notify(1)
+ for _, ch := range chs {
+ select {
+ case <-ch:
+ default:
+ b.Fatalf("channel did not receive event")
+ }
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/syncevent/receiver.go b/pkg/syncevent/receiver.go
new file mode 100644
index 000000000..5c86e5400
--- /dev/null
+++ b/pkg/syncevent/receiver.go
@@ -0,0 +1,103 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/atomicbitops"
+)
+
+// Receiver is an event sink that holds pending events and invokes a callback
+// whenever new events become pending. Receiver's methods may be called
+// concurrently from multiple goroutines.
+//
+// Receiver.Init() must be called before first use.
+type Receiver struct {
+ // pending is the set of pending events. pending is accessed using atomic
+ // memory operations.
+ pending uint64
+
+ // cb is notified when new events become pending. cb is immutable after
+ // Init().
+ cb ReceiverCallback
+}
+
+// ReceiverCallback receives callbacks from a Receiver.
+type ReceiverCallback interface {
+ // NotifyPending is called when the corresponding Receiver has new pending
+ // events.
+ //
+ // NotifyPending is called synchronously from Receiver.Notify(), so
+ // implementations must not take locks that may be held by callers of
+ // Receiver.Notify(). NotifyPending may be called concurrently from
+ // multiple goroutines.
+ NotifyPending()
+}
+
+// Init must be called before first use of r.
+func (r *Receiver) Init(cb ReceiverCallback) {
+ r.cb = cb
+}
+
+// Pending returns the set of pending events.
+func (r *Receiver) Pending() Set {
+ return Set(atomic.LoadUint64(&r.pending))
+}
+
+// Notify sets the given events as pending.
+func (r *Receiver) Notify(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already pending.
+ if p&es == es {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-OR because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p|es)) {
+ // If the CAS fails, fall back to atomic-OR.
+ atomicbitops.OrUint64(&r.pending, uint64(es))
+ }
+ r.cb.NotifyPending()
+}
+
+// Ack unsets the given events as pending.
+func (r *Receiver) Ack(es Set) {
+ p := Set(atomic.LoadUint64(&r.pending))
+ // Optimization: Skip the atomic CAS on r.pending if all events are
+ // already not pending.
+ if p&es == 0 {
+ return
+ }
+ // When this is uncontended (the common case), CAS is faster than
+ // atomic-AND because the former is inlined and the latter (which we
+ // implement in assembly ourselves) is not.
+ if !atomic.CompareAndSwapUint64(&r.pending, uint64(p), uint64(p&^es)) {
+ // If the CAS fails, fall back to atomic-AND.
+ atomicbitops.AndUint64(&r.pending, ^uint64(es))
+ }
+}
+
+// PendingAndAckAll unsets all events as pending and returns the set of
+// previously-pending events.
+//
+// PendingAndAckAll should only be used in preference to a call to Pending
+// followed by a conditional call to Ack when the caller expects events to be
+// pending (e.g. after a call to ReceiverCallback.NotifyPending()).
+func (r *Receiver) PendingAndAckAll() Set {
+ return Set(atomic.SwapUint64(&r.pending, 0))
+}
diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go
new file mode 100644
index 000000000..ddffb171a
--- /dev/null
+++ b/pkg/syncevent/source.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+// Source represents an event source.
+type Source interface {
+ // SubscribeEvents causes the Source to notify the given Receiver of the
+ // given subset of events.
+ //
+ // Preconditions: r != nil. The ReceiverCallback for r must not take locks
+ // that are ordered prior to the Source; for example, it cannot call any
+ // Source methods.
+ SubscribeEvents(r *Receiver, filter Set) SubscriptionID
+
+ // UnsubscribeEvents causes the Source to stop notifying the Receiver
+ // subscribed by a previous call to SubscribeEvents that returned the given
+ // SubscriptionID.
+ //
+ // Preconditions: UnsubscribeEvents may be called at most once for any
+ // given SubscriptionID.
+ UnsubscribeEvents(id SubscriptionID)
+}
+
+// SubscriptionID identifies a call to Source.SubscribeEvents.
+type SubscriptionID uint64
+
+// UnsubscribeAndAck is a convenience function that unsubscribes r from the
+// given events from src and also clears them from r.
+func UnsubscribeAndAck(src Source, r *Receiver, filter Set, id SubscriptionID) {
+ src.UnsubscribeEvents(id)
+ r.Ack(filter)
+}
+
+// NoopSource implements Source by never sending events to subscribed
+// Receivers.
+type NoopSource struct{}
+
+// SubscribeEvents implements Source.SubscribeEvents.
+func (NoopSource) SubscribeEvents(*Receiver, Set) SubscriptionID {
+ return 0
+}
+
+// UnsubscribeEvents implements Source.UnsubscribeEvents.
+func (NoopSource) UnsubscribeEvents(SubscriptionID) {
+}
+
+// See Broadcaster for a non-noop implementations of Source.
diff --git a/pkg/syncevent/syncevent.go b/pkg/syncevent/syncevent.go
new file mode 100644
index 000000000..9fb6a06de
--- /dev/null
+++ b/pkg/syncevent/syncevent.go
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package syncevent provides efficient primitives for goroutine
+// synchronization based on event bitmasks.
+package syncevent
+
+// Set is a bitmask where each bit represents a distinct user-defined event.
+// The event package does not treat any bits in Set specially.
+type Set uint64
+
+const (
+ // NoEvents is a Set containing no events.
+ NoEvents = Set(0)
+
+ // AllEvents is a Set containing all possible events.
+ AllEvents = ^Set(0)
+
+ // MaxEvents is the number of distinct events that can be represented by a Set.
+ MaxEvents = 64
+)
diff --git a/pkg/syncevent/syncevent_example_test.go b/pkg/syncevent/syncevent_example_test.go
new file mode 100644
index 000000000..bfb18e2ea
--- /dev/null
+++ b/pkg/syncevent/syncevent_example_test.go
@@ -0,0 +1,108 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "fmt"
+ "sync/atomic"
+ "time"
+)
+
+func Example_ioReadinessInterrputible() {
+ const (
+ evReady = Set(1 << iota)
+ evInterrupt
+ )
+ errNotReady := fmt.Errorf("not ready for I/O")
+
+ // State of some I/O object.
+ var (
+ br Broadcaster
+ ready uint32
+ )
+ doIO := func() error {
+ if atomic.LoadUint32(&ready) == 0 {
+ return errNotReady
+ }
+ return nil
+ }
+ go func() {
+ // The I/O object eventually becomes ready for I/O.
+ time.Sleep(100 * time.Millisecond)
+ // When it does, it first ensures that future calls to isReady() return
+ // true, then broadcasts the readiness event to Receivers.
+ atomic.StoreUint32(&ready, 1)
+ br.Broadcast(evReady)
+ }()
+
+ // Each user of the I/O object owns a Waiter.
+ var w Waiter
+ w.Init()
+ // The Waiter may be asynchronously interruptible, e.g. for signal
+ // handling in the sentry.
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ w.Receiver().Notify(evInterrupt)
+ }()
+
+ // To use the I/O object:
+ //
+ // Optionally, if the I/O object is likely to be ready, attempt I/O first.
+ err := doIO()
+ if err == nil {
+ // Success, we're done.
+ return /* nil */
+ }
+ if err != errNotReady {
+ // Failure, I/O failed for some reason other than readiness.
+ return /* err */
+ }
+ // Subscribe for readiness events from the I/O object.
+ id := br.SubscribeEvents(w.Receiver(), evReady)
+ // When we are finished blocking, unsubscribe from readiness events and
+ // remove readiness events from the pending event set.
+ defer UnsubscribeAndAck(&br, w.Receiver(), evReady, id)
+ for {
+ // Attempt I/O again. This must be done after the call to SubscribeEvents,
+ // since the I/O object might have become ready between the previous call
+ // to doIO and the call to SubscribeEvents.
+ err = doIO()
+ if err == nil {
+ return /* nil */
+ }
+ if err != errNotReady {
+ return /* err */
+ }
+ // Block until either the I/O object indicates it is ready, or we are
+ // interrupted.
+ events := w.Wait()
+ if events&evInterrupt != 0 {
+ // In the specific case of sentry signal handling, signal delivery
+ // is handled by another system, so we aren't responsible for
+ // acknowledging evInterrupt.
+ return /* errInterrupted */
+ }
+ // Note that, in a concurrent context, the I/O object might become
+ // ready and then not ready again. To handle this:
+ //
+ // - evReady must be acknowledged before calling doIO() again (rather
+ // than after), so that if the I/O object becomes ready *again* after
+ // the call to doIO(), the readiness event is not lost.
+ //
+ // - We must loop instead of just calling doIO() once after receiving
+ // evReady.
+ w.Ack(evReady)
+ }
+}
diff --git a/pkg/syncevent/waiter_amd64.s b/pkg/syncevent/waiter_amd64.s
new file mode 100644
index 000000000..985b56ae5
--- /dev/null
+++ b/pkg/syncevent/waiter_amd64.s
@@ -0,0 +1,32 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVQ g+0(FP), DI
+ MOVQ wg+8(FP), SI
+
+ MOVQ $·preparingG(SB), AX
+ LOCK
+ CMPXCHGQ DI, 0(SI)
+
+ SETEQ AX
+ MOVB AX, ret+16(FP)
+
+ RET
+
diff --git a/pkg/syncevent/waiter_arm64.s b/pkg/syncevent/waiter_arm64.s
new file mode 100644
index 000000000..20d7ac23b
--- /dev/null
+++ b/pkg/syncevent/waiter_arm64.s
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+//
+// func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
+TEXT ·waiterUnlock(SB),NOSPLIT,$0-24
+ MOVD wg+8(FP), R0
+ MOVD $·preparingG(SB), R1
+ MOVD g+0(FP), R2
+again:
+ LDAXR (R0), R3
+ CMP R1, R3
+ BNE ok
+ STLXR R2, (R0), R3
+ CBNZ R3, again
+ok:
+ CSET EQ, R0
+ MOVB R0, ret+16(FP)
+ RET
+
diff --git a/pkg/fspath/builder_unsafe.go b/pkg/syncevent/waiter_asm_unsafe.go
index 75606808d..0995e9053 100644
--- a/pkg/fspath/builder_unsafe.go
+++ b/pkg/syncevent/waiter_asm_unsafe.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,16 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package fspath
+// +build amd64 arm64
+
+package syncevent
import (
"unsafe"
)
-// String returns the accumulated string. No other methods should be called
-// after String.
-func (b *Builder) String() string {
- bs := b.buf[b.start:]
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
-}
+// See waiter_noasm_unsafe.go for a description of waiterUnlock.
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool
diff --git a/pkg/syncevent/waiter_noasm_unsafe.go b/pkg/syncevent/waiter_noasm_unsafe.go
new file mode 100644
index 000000000..1c4b0e39a
--- /dev/null
+++ b/pkg/syncevent/waiter_noasm_unsafe.go
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// waiterUnlock is called from g0, so when the race detector is enabled,
+// waiterUnlock must be implemented in assembly since no race context is
+// available.
+//
+// +build !race
+// +build !amd64,!arm64
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// waiterUnlock is the "unlock function" passed to runtime.gopark by
+// Waiter.Wait*. wg is &Waiter.g, and g is a pointer to the calling runtime.g.
+// waiterUnlock returns true if Waiter.Wait should sleep and false if sleeping
+// should be aborted.
+//
+//go:nosplit
+func waiterUnlock(g unsafe.Pointer, wg *unsafe.Pointer) bool {
+ // The only way this CAS can fail is if a call to Waiter.NotifyPending()
+ // has replaced *wg with nil, in which case we should not sleep.
+ return atomic.CompareAndSwapPointer(wg, (unsafe.Pointer)(&preparingG), g)
+}
diff --git a/pkg/syncevent/waiter_test.go b/pkg/syncevent/waiter_test.go
new file mode 100644
index 000000000..3c8cbcdd8
--- /dev/null
+++ b/pkg/syncevent/waiter_test.go
@@ -0,0 +1,414 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+func TestWaiterAlreadyPending(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ w.Notify(want)
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterAsyncNotify(t *testing.T) {
+ var w Waiter
+ w.Init()
+ want := Set(1)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ w.Notify(want)
+ }()
+ if got := w.Wait(); got != want {
+ t.Errorf("Waiter.Wait: got %#x, wanted %#x", got, want)
+ }
+}
+
+func TestWaiterWaitFor(t *testing.T) {
+ var w Waiter
+ w.Init()
+ evWaited := Set(1)
+ evOther := Set(2)
+ w.Notify(evOther)
+ notifiedEvent := uint32(0)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ atomic.StoreUint32(&notifiedEvent, 1)
+ w.Notify(evWaited)
+ }()
+ if got, want := w.WaitFor(evWaited), evWaited|evOther; got != want {
+ t.Errorf("Waiter.WaitFor: got %#x, wanted %#x", got, want)
+ }
+ if atomic.LoadUint32(&notifiedEvent) == 0 {
+ t.Errorf("Waiter.WaitFor returned before goroutine notified waited-for event")
+ }
+}
+
+func TestWaiterWaitAndAckAll(t *testing.T) {
+ var w Waiter
+ w.Init()
+ w.Notify(AllEvents)
+ if got := w.WaitAndAckAll(); got != AllEvents {
+ t.Errorf("Waiter.WaitAndAckAll: got %#x, wanted %#x", got, AllEvents)
+ }
+ if got := w.Pending(); got != NoEvents {
+ t.Errorf("Waiter.WaitAndAckAll did not ack all events: got %#x, wanted 0", got)
+ }
+}
+
+// BenchmarkWaiterX, BenchmarkSleeperX, and BenchmarkChannelX benchmark usage
+// pattern X (described in terms of Waiter) with Waiter, sleep.Sleeper, and
+// buffered chan struct{} respectively. When the maximum number of event
+// sources is relevant, we use 3 event sources because this is representative
+// of the kernel.Task.block() use case: an interrupt source, a timeout source,
+// and the actual event source being waited on.
+
+// Event set used by most benchmarks.
+const evBench Set = 1
+
+// BenchmarkXxxNotifyRedundant measures how long it takes to notify a Waiter of
+// an event that is already pending.
+
+func BenchmarkWaiterNotifyRedundant(b *testing.B) {
+ var w Waiter
+ w.Init()
+ w.Notify(evBench)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyRedundant(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+ w.Assert()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ }
+}
+
+func BenchmarkChannelNotifyRedundant(b *testing.B) {
+ ch := make(chan struct{}, 1)
+ ch <- struct{}{}
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }
+}
+
+// BenchmarkXxxNotifyWaitAck measures how long it takes to notify a Waiter an
+// event, return that event using a blocking check, and then unset the event as
+// pending.
+
+func BenchmarkWaiterNotifyWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Assert()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+
+ // wait + ack
+ <-ch
+ }
+}
+
+// BenchmarkSleeperMultiNotifyWaitAck is equivalent to
+// BenchmarkSleeperNotifyWaitAck, but also includes allocation of a
+// temporary sleep.Waker. This is necessary when multiple goroutines may wait
+// for the same event, since each sleep.Waker can wake only a single
+// sleep.Sleeper.
+//
+// The syncevent package does not require a distinct object for each
+// waiter-waker relationship, so BenchmarkWaiterNotifyWaitAck and
+// BenchmarkWaiterMultiNotifyWaitAck would be identical. The analogous state
+// for channels, runtime.sudog, is inescapably runtime-allocated, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelMultiNotifyWaitAck would
+// also be identical.
+
+func BenchmarkSleeperMultiNotifyWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ // The sleep package doesn't provide sync.Pool allocation of Wakers;
+ // we do for a fairer comparison.
+ wakerPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Waker{}
+ },
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := wakerPool.Get().(*sleep.Waker)
+ s.AddWaker(w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ wakerPool.Put(w)
+ }
+}
+
+// BenchmarkXxxTempNotifyWaitAck is equivalent to NotifyWaitAck, but also
+// includes allocation of a temporary Waiter. This models the case where a
+// goroutine not already associated with a Waiter needs one in order to block.
+//
+// The analogous state for channels is built into runtime.g, so
+// BenchmarkChannelNotifyWaitAck and BenchmarkChannelTempNotifyWaitAck would be
+// identical.
+
+func BenchmarkWaiterTempNotifyWaitAck(b *testing.B) {
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := GetWaiter()
+ w.Notify(evBench)
+ w.Wait()
+ w.Ack(evBench)
+ PutWaiter(w)
+ }
+}
+
+func BenchmarkSleeperTempNotifyWaitAck(b *testing.B) {
+ // The sleep package doesn't provide sync.Pool allocation of Sleepers;
+ // we do for a fairer comparison.
+ sleeperPool := sync.Pool{
+ New: func() interface{} {
+ return &sleep.Sleeper{}
+ },
+ }
+ var w sleep.Waker
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s := sleeperPool.Get().(*sleep.Sleeper)
+ s.AddWaker(&w, 0)
+ w.Assert()
+ s.Fetch(true)
+ s.Done()
+ sleeperPool.Put(s)
+ }
+}
+
+// BenchmarkXxxNotifyWaitMultiAck is equivalent to NotifyWaitAck, but allows
+// for multiple event sources.
+
+func BenchmarkWaiterNotifyWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w.Notify(evBench)
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ws[0].Assert()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, wanted 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // notify
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+
+ // wait + clear
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitAck measures how long it takes to wait for an
+// event while another goroutine signals the event. This assumes that a new
+// goroutine doesn't run immediately (i.e. the creator of a new goroutine is
+// allowed to go to sleep before the new goroutine has a chance to run).
+
+func BenchmarkWaiterNotifyAsyncWaitAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(1)
+ }()
+ w.Wait()
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitAck(b *testing.B) {
+ var s sleep.Sleeper
+ var w sleep.Waker
+ s.AddWaker(&w, 0)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Assert()
+ }()
+ s.Fetch(true)
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitAck(b *testing.B) {
+ ch := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch <- struct{}{}:
+ default:
+ }
+ }()
+ <-ch
+ }
+}
+
+// BenchmarkXxxNotifyAsyncWaitMultiAck is equivalent to NotifyAsyncWaitAck, but
+// allows for multiple event sources.
+
+func BenchmarkWaiterNotifyAsyncWaitMultiAck(b *testing.B) {
+ var w Waiter
+ w.Init()
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ w.Notify(evBench)
+ }()
+ if e := w.Wait(); e != evBench {
+ b.Fatalf("Wait: got %#x, wanted %#x", e, evBench)
+ }
+ w.Ack(evBench)
+ }
+}
+
+func BenchmarkSleeperNotifyAsyncWaitMultiAck(b *testing.B) {
+ var s sleep.Sleeper
+ var ws [3]sleep.Waker
+ for i := range ws {
+ s.AddWaker(&ws[i], i)
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ ws[0].Assert()
+ }()
+ if id, _ := s.Fetch(true); id != 0 {
+ b.Fatalf("Fetch: got %d, expected 0", id)
+ }
+ }
+}
+
+func BenchmarkChannelNotifyAsyncWaitMultiAck(b *testing.B) {
+ ch0 := make(chan struct{}, 1)
+ ch1 := make(chan struct{}, 1)
+ ch2 := make(chan struct{}, 1)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ go func() {
+ select {
+ case ch0 <- struct{}{}:
+ default:
+ }
+ }()
+
+ select {
+ case <-ch0:
+ // ok
+ case <-ch1:
+ b.Fatalf("received from ch1")
+ case <-ch2:
+ b.Fatalf("received from ch2")
+ }
+ }
+}
diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go
new file mode 100644
index 000000000..112e0e604
--- /dev/null
+++ b/pkg/syncevent/waiter_unsafe.go
@@ -0,0 +1,206 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build go1.11
+// +build !go1.15
+
+// Check go:linkname function signatures when updating Go version.
+
+package syncevent
+
+import (
+ "sync/atomic"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+//go:linkname gopark runtime.gopark
+func gopark(unlockf func(unsafe.Pointer, *unsafe.Pointer) bool, wg *unsafe.Pointer, reason uint8, traceEv byte, traceskip int)
+
+//go:linkname goready runtime.goready
+func goready(g unsafe.Pointer, traceskip int)
+
+const (
+ waitReasonSelect = 9 // Go: src/runtime/runtime2.go
+ traceEvGoBlockSelect = 24 // Go: src/runtime/trace.go
+)
+
+// Waiter allows a goroutine to block on pending events received by a Receiver.
+//
+// Waiter.Init() must be called before first use.
+type Waiter struct {
+ r Receiver
+
+ // g is one of:
+ //
+ // - nil: No goroutine is blocking in Wait.
+ //
+ // - &preparingG: A goroutine is in Wait preparing to sleep, but hasn't yet
+ // completed waiterUnlock(). Thus the wait can only be interrupted by
+ // replacing the value of g with nil (the G may not be in state Gwaiting
+ // yet, so we can't call goready.)
+ //
+ // - Otherwise: g is a pointer to the runtime.g in state Gwaiting for the
+ // goroutine blocked in Wait, which can only be woken by calling goready.
+ g unsafe.Pointer `state:"zerovalue"`
+}
+
+// Sentinel object for Waiter.g.
+var preparingG struct{}
+
+// Init must be called before first use of w.
+func (w *Waiter) Init() {
+ w.r.Init(w)
+}
+
+// Receiver returns the Receiver that receives events that unblock calls to
+// w.Wait().
+func (w *Waiter) Receiver() *Receiver {
+ return &w.r
+}
+
+// Pending returns the set of pending events.
+func (w *Waiter) Pending() Set {
+ return w.r.Pending()
+}
+
+// Wait blocks until at least one event is pending, then returns the set of
+// pending events. It does not affect the set of pending events; callers must
+// call w.Ack() to do so, or use w.WaitAndAck() instead.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) Wait() Set {
+ return w.WaitFor(AllEvents)
+}
+
+// WaitFor blocks until at least one event in es is pending, then returns the
+// set of pending events (including those not in es). It does not affect the
+// set of pending events; callers must call w.Ack() to do so.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitFor(es Set) Set {
+ for {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending.
+ if p := w.r.Pending(); p&es != NoEvents {
+ return p
+ }
+
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if p := w.r.Pending(); p&es != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+ }
+}
+
+// Ack marks the given events as not pending.
+func (w *Waiter) Ack(es Set) {
+ w.r.Ack(es)
+}
+
+// WaitAndAckAll blocks until at least one event is pending, then marks all
+// events as not pending and returns the set of previously-pending events.
+//
+// Precondition: Only one goroutine may call any Wait* method at a time.
+func (w *Waiter) WaitAndAckAll() Set {
+ // Optimization: Skip the atomic store to w.g if an event is already
+ // pending. Call Pending() first since, in the common case that events are
+ // not yet pending, this skips an atomic swap on w.r.pending.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+
+ for {
+ // Indicate that we're preparing to go to sleep.
+ atomic.StorePointer(&w.g, (unsafe.Pointer)(&preparingG))
+
+ // If an event is pending, abort the sleep.
+ if w.r.Pending() != NoEvents {
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ atomic.StorePointer(&w.g, nil)
+ return p
+ }
+ }
+
+ // If w.g is still preparingG (i.e. w.NotifyPending() has not been
+ // called or has not reached atomic.SwapPointer()), go to sleep until
+ // w.NotifyPending() => goready().
+ gopark(waiterUnlock, &w.g, waitReasonSelect, traceEvGoBlockSelect, 0)
+
+ // Check for pending events. We call PendingAndAckAll() directly now since
+ // we only expect to be woken after events become pending.
+ if p := w.r.PendingAndAckAll(); p != NoEvents {
+ return p
+ }
+ }
+}
+
+// Notify marks the given events as pending, possibly unblocking concurrent
+// calls to w.Wait() or w.WaitFor().
+func (w *Waiter) Notify(es Set) {
+ w.r.Notify(es)
+}
+
+// NotifyPending implements ReceiverCallback.NotifyPending. Users of Waiter
+// should not call NotifyPending.
+func (w *Waiter) NotifyPending() {
+ // Optimization: Skip the atomic swap on w.g if there is no sleeping
+ // goroutine. NotifyPending is called after w.r.Pending() is updated, so
+ // concurrent and future calls to w.Wait() will observe pending events and
+ // abort sleeping.
+ if atomic.LoadPointer(&w.g) == nil {
+ return
+ }
+ // Wake a sleeping G, or prevent a G that is preparing to sleep from doing
+ // so. Swap is needed here to ensure that only one call to NotifyPending
+ // calls goready.
+ if g := atomic.SwapPointer(&w.g, nil); g != nil && g != (unsafe.Pointer)(&preparingG) {
+ goready(g, 0)
+ }
+}
+
+var waiterPool = sync.Pool{
+ New: func() interface{} {
+ w := &Waiter{}
+ w.Init()
+ return w
+ },
+}
+
+// GetWaiter returns an unused Waiter. PutWaiter should be called to release
+// the Waiter once it is no longer needed.
+//
+// Where possible, users should prefer to associate each goroutine that calls
+// Waiter.Wait() with a distinct pre-allocated Waiter to avoid allocation of
+// Waiters in hot paths.
+func GetWaiter() *Waiter {
+ return waiterPool.Get().(*Waiter)
+}
+
+// PutWaiter releases an unused Waiter previously returned by GetWaiter.
+func PutWaiter(w *Waiter) {
+ waiterPool.Put(w)
+}
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 2269f6237..4b5a0fca6 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -29,6 +29,7 @@ var (
EACCES = error(syscall.EACCES)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
+ EBADFD = error(syscall.EBADFD)
EBUSY = error(syscall.EBUSY)
ECHILD = error(syscall.ECHILD)
ECONNREFUSED = error(syscall.ECONNREFUSED)
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ea0a0409a..3c552988a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
+ _, err := r.CreateEndpoint(&wq)
if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err)
}
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
+ // Endpoint will be closed in deferred s.Close (above).
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() {
c1.Close()
c2.Close()
+ s.Close()
+ s.Wait()
}
if err := l.Close(); err != nil {
@@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 150310c11..17e94c562 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) {
vv.views = append(vv.views, vv2.views...)
vv.size += vv2.size
}
+
+// AppendView appends the given view into this vectorised view.
+func (vv *VectorisedView) AppendView(v View) {
+ vv.views = append(vv.views, v)
+ vv.size += len(v)
+}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 4d6ae0871..c6c160dfc 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -161,6 +161,20 @@ func FragmentFlags(flags uint8) NetworkChecker {
}
}
+// ReceiveTClass creates a checker that checks the TCLASS field in
+// ControlMessages.
+func ReceiveTClass(want uint32) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasTClass {
+ t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
+ }
+ if got := cm.TClass; got != want {
+ t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ }
+ }
+}
+
// ReceiveTOS creates a checker that checks the TOS field in ControlMessages.
func ReceiveTOS(want uint8) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
index f7dc4f720..80ddbd442 100644
--- a/pkg/tcpip/iptables/iptables.go
+++ b/pkg/tcpip/iptables/iptables.go
@@ -156,25 +156,54 @@ func EmptyNatTable() Table {
}
}
+// A chainVerdict is what a table decides should be done with a packet.
+type chainVerdict int
+
+const (
+ // chainAccept indicates the packet should continue through netstack.
+ chainAccept chainVerdict = iota
+
+ // chainAccept indicates the packet should be dropped.
+ chainDrop
+
+ // chainReturn indicates the packet should return to the calling chain
+ // or the underflow rule of a builtin chain.
+ chainReturn
+)
+
+
// Check runs pkt through the rules for hook. It returns true when the packet
// should continue traversing the network stack and false when it should be
// dropped.
//
// Precondition: pkt.NetworkHeader is set.
func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
- // TODO(gvisor.dev/issue/170): A lot of this is uncomplicated because
- // we're missing features. Jumps, the call stack, etc. aren't checked
- // for yet because we're yet to support them.
-
// Go through each table containing the hook.
for _, tablename := range it.Priorities[hook] {
- switch verdict := it.checkTable(hook, pkt, tablename); verdict {
+ table := it.Tables[tablename]
+ ruleIdx := table.BuiltinChains[hook]
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx); verdict {
// If the table returns Accept, move on to the next table.
- case TableAccept:
+ case chainAccept:
continue
// The Drop verdict is final.
- case TableDrop:
+ case chainDrop:
return false
+ case chainReturn:
+ // Any Return from a built-in chain means we have to
+ // call the underflow.
+ underflow := table.Rules[table.Underflows[hook]]
+ switch v, _ := underflow.Target.Action(pkt); v {
+ case RuleAccept:
+ continue
+ case RuleDrop:
+ return false
+ case RuleJump, RuleReturn:
+ panic("Underflows should only return RuleAccept or RuleDrop.")
+ default:
+ panic(fmt.Sprintf("Unknown verdict: %d", v))
+ }
+
default:
panic(fmt.Sprintf("Unknown verdict %v.", verdict))
}
@@ -185,37 +214,37 @@ func (it *IPTables) Check(hook Hook, pkt tcpip.PacketBuffer) bool {
}
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename string) TableVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
- table := it.Tables[tablename]
- for ruleIdx := table.BuiltinChains[hook]; ruleIdx < len(table.Rules); ruleIdx++ {
- switch verdict := it.checkRule(hook, pkt, table, ruleIdx); verdict {
+ for ruleIdx < len(table.Rules) {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx); verdict {
case RuleAccept:
- return TableAccept
+ return chainAccept
case RuleDrop:
- return TableDrop
-
- case RuleContinue:
- continue
+ return chainDrop
case RuleReturn:
- // TODO(gvisor.dev/issue/170): We don't implement jump
- // yet, so any Return is from a built-in chain. That
- // means we have to to call the underflow.
- underflow := table.Rules[table.Underflows[hook]]
- // Underflow is guaranteed to be an unconditional
- // ACCEPT or DROP.
- switch v, _ := underflow.Target.Action(pkt, underflow.Filter); v {
- case RuleAccept:
- return TableAccept
- case RuleDrop:
- return TableDrop
- case RuleContinue, RuleReturn:
- panic("Underflows should only return RuleAccept or RuleDrop.")
+ return chainReturn
+
+ case RuleJump:
+ // "Jumping" to the next rule just means we're
+ // continuing on down the list.
+ if jumpTo == ruleIdx+1 {
+ ruleIdx++
+ continue
+ }
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo); verdict {
+ case chainAccept:
+ return chainAccept
+ case chainDrop:
+ return chainDrop
+ case chainReturn:
+ ruleIdx++
+ continue
default:
- panic(fmt.Sprintf("Unknown verdict: %d", v))
+ panic(fmt.Sprintf("Unknown verdict: %d", verdict))
}
default:
@@ -226,11 +255,11 @@ func (it *IPTables) checkTable(hook Hook, pkt tcpip.PacketBuffer, tablename stri
// We got through the entire table without a decision. Default to DROP
// for safety.
- return TableDrop
+ return chainDrop
}
// Precondition: pk.NetworkHeader is set.
-func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) RuleVerdict {
+func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ruleIdx int) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// If pkt.NetworkHeader hasn't been set yet, it will be contained in
@@ -242,7 +271,8 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
// First check whether the packet matches the IP header filter.
// TODO(gvisor.dev/issue/170): Support other fields of the filter.
if rule.Filter.Protocol != 0 && rule.Filter.Protocol != header.IPv4(pkt.NetworkHeader).TransportProtocol() {
- return RuleContinue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
// Go through each rule matcher. If they all match, run
@@ -250,14 +280,14 @@ func (it *IPTables) checkRule(hook Hook, pkt tcpip.PacketBuffer, table Table, ru
for _, matcher := range rule.Matchers {
matches, hotdrop := matcher.Match(hook, pkt, "")
if hotdrop {
- return RuleDrop
+ return RuleDrop, 0
}
if !matches {
- return RuleContinue
+ // Continue on to the next rule.
+ return RuleJump, ruleIdx + 1
}
}
// All the matchers matched, so run the target.
- verdict, _ := rule.Target.Action(pkt, rule.Filter)
- return verdict
+ return rule.Target.Action(pkt, rule.Filter)
}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
index a75938da3..5dbb28145 100644
--- a/pkg/tcpip/iptables/targets.go
+++ b/pkg/tcpip/iptables/targets.go
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// This file contains various Targets.
-
package iptables
import (
@@ -26,16 +24,16 @@ import (
type AcceptTarget struct{}
// Action implements Target.Action.
-func (AcceptTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
- return RuleAccept, ""
+func (AcceptTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
+ return RuleAccept, 0
}
// DropTarget drops packets.
type DropTarget struct{}
// Action implements Target.Action.
-func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
- return RuleDrop, ""
+func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
+ return RuleDrop, 0
}
// ErrorTarget logs an error and drops the packet. It represents a target that
@@ -43,9 +41,9 @@ func (DropTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (Rule
type ErrorTarget struct{}
// Action implements Target.Action.
-func (ErrorTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
+func (ErrorTarget) Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
log.Debugf("ErrorTarget triggered.")
- return RuleDrop, ""
+ return RuleDrop, 0
}
// UserChainTarget marks a rule as the beginning of a user chain.
@@ -54,7 +52,7 @@ type UserChainTarget struct {
}
// Action implements Target.Action.
-func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) {
+func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, int) {
panic("UserChainTarget should never be called.")
}
@@ -63,8 +61,8 @@ func (UserChainTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict,
type ReturnTarget struct{}
// Action implements Target.Action.
-func (ReturnTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, string) {
- return RuleReturn, ""
+func (ReturnTarget) Action(tcpip.PacketBuffer, IPHeaderFilter) (RuleVerdict, int) {
+ return RuleReturn, 0
}
// RedirectTarget redirects the packet by modifying the destination port/IP.
@@ -88,13 +86,13 @@ type RedirectTarget struct {
}
// Action implements Target.Action.
-func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string) {
+func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int) {
headerView := pkt.Data.First()
// Network header should be set.
netHeader := header.IPv4(headerView)
if netHeader == nil {
- return RuleDrop, ""
+ return RuleDrop, 0
}
// TODO(gvisor.dev/issue/170): Check Flags in RedirectTarget if
@@ -111,7 +109,7 @@ func (rt RedirectTarget) Action(pkt tcpip.PacketBuffer, filter IPHeaderFilter) (
tcp := header.TCP(headerView[hlen:])
tcp.SetDestinationPort(rt.MinPort)
default:
- return RuleDrop, ""
+ return RuleDrop, 0
}
- return RuleAccept, ""
+ return RuleAccept, 0
}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
index 0102831d0..8bd3a2c94 100644
--- a/pkg/tcpip/iptables/types.go
+++ b/pkg/tcpip/iptables/types.go
@@ -56,17 +56,6 @@ const (
NumHooks
)
-// A TableVerdict is what a table decides should be done with a packet.
-type TableVerdict int
-
-const (
- // TableAccept indicates the packet should continue through netstack.
- TableAccept TableVerdict = iota
-
- // TableDrop indicates the packet should be dropped.
- TableDrop
-)
-
// A RuleVerdict is what a rule decides should be done with a packet.
type RuleVerdict int
@@ -74,12 +63,12 @@ const (
// RuleAccept indicates the packet should continue through netstack.
RuleAccept RuleVerdict = iota
- // RuleContinue indicates the packet should continue to the next rule.
- RuleContinue
-
// RuleDrop indicates the packet should be dropped.
RuleDrop
+ // RuleJump indicates the packet should jump to another chain.
+ RuleJump
+
// RuleReturn indicates the packet should return to the previous chain.
RuleReturn
)
@@ -175,5 +164,5 @@ type Target interface {
// Action takes an action on the packet and returns a verdict on how
// traversal should (or should not) continue. If the return value is
// Jump, it also returns the name of the chain to jump to.
- Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, string)
+ Action(packet tcpip.PacketBuffer, filter IPHeaderFilter) (RuleVerdict, int)
}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 3974c464e..b8b93e78e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = ["channel.go"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 78d447acd..5944ba190 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -20,6 +20,7 @@ package channel
import (
"context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,6 +34,118 @@ type PacketInfo struct {
Route stack.Route
}
+// Notification is the interface for receiving notification from the packet
+// queue.
+type Notification interface {
+ // WriteNotify will be called when a write happens to the queue.
+ WriteNotify()
+}
+
+// NotificationHandle is an opaque handle to the registered notification target.
+// It can be used to unregister the notification when no longer interested.
+//
+// +stateify savable
+type NotificationHandle struct {
+ n Notification
+}
+
+type queue struct {
+ // mu protects fields below.
+ mu sync.RWMutex
+ // c is the outbound packet channel. Sending to c should hold mu.
+ c chan PacketInfo
+ numWrite int
+ numRead int
+ notify []*NotificationHandle
+}
+
+func (q *queue) Close() {
+ close(q.c)
+}
+
+func (q *queue) Read() (PacketInfo, bool) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ select {
+ case p := <-q.c:
+ q.numRead++
+ return p, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ // We have to receive from channel without holding the lock, since it can
+ // block indefinitely. This will cause a window that numWrite - numRead
+ // produces a larger number, but won't go to negative. numWrite >= numRead
+ // still holds.
+ select {
+ case pkt := <-q.c:
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.numRead++
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) Write(p PacketInfo) bool {
+ wrote := false
+
+ // It's important to make sure nobody can see numWrite until we increment it,
+ // so numWrite >= numRead holds.
+ q.mu.Lock()
+ select {
+ case q.c <- p:
+ wrote = true
+ q.numWrite++
+ default:
+ }
+ notify := q.notify
+ q.mu.Unlock()
+
+ if wrote {
+ // Send notification outside of lock.
+ for _, h := range notify {
+ h.n.WriteNotify()
+ }
+ }
+ return wrote
+}
+
+func (q *queue) Num() int {
+ q.mu.RLock()
+ defer q.mu.RUnlock()
+ n := q.numWrite - q.numRead
+ if n < 0 {
+ panic("numWrite < numRead")
+ }
+ return n
+}
+
+func (q *queue) AddNotify(notify Notification) *NotificationHandle {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ h := &NotificationHandle{n: notify}
+ q.notify = append(q.notify, h)
+ return h
+}
+
+func (q *queue) RemoveNotify(handle *NotificationHandle) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ // Make a copy, since we reads the array outside of lock when notifying.
+ notify := make([]*NotificationHandle, 0, len(q.notify))
+ for _, h := range q.notify {
+ if h != handle {
+ notify = append(notify, h)
+ }
+ }
+ q.notify = notify
+}
+
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
@@ -41,14 +154,16 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
LinkEPCapabilities stack.LinkEndpointCapabilities
- // c is where outbound packets are queued.
- c chan PacketInfo
+ // Outbound packet queue.
+ q *queue
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- c: make(chan PacketInfo, size),
+ q: &queue{
+ c: make(chan PacketInfo, size),
+ },
mtu: mtu,
linkAddr: linkAddr,
}
@@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
// Close closes e. Further packet injections will panic. Reads continue to
// succeed until all packets are read.
func (e *Endpoint) Close() {
- close(e.c)
+ e.q.Close()
}
-// Read does non-blocking read for one packet from the outbound packet queue.
+// Read does non-blocking read one packet from the outbound packet queue.
func (e *Endpoint) Read() (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- default:
- return PacketInfo{}, false
- }
+ return e.q.Read()
}
// ReadContext does blocking read for one packet from the outbound packet queue.
// It can be cancelled by ctx, and in this case, it returns false.
func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- case <-ctx.Done():
- return PacketInfo{}, false
- }
+ return e.q.ReadContext(ctx)
}
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
- select {
- case <-e.c:
- c++
- default:
+ if _, ok := e.Read(); !ok {
return c
}
+ c++
}
}
+// NumQueued returns the number of packet queued for outbound.
+func (e *Endpoint) NumQueued() int {
+ return e.q.Num()
+}
+
// InjectInbound injects an inbound packet.
func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
@@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
Route: route,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
@@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
route.Release()
payloadView := pkts[0].Data.ToView()
n := 0
-packetLoop:
for _, pkt := range pkts {
off := pkt.DataOffset
size := pkt.DataSize
@@ -185,12 +289,10 @@ packetLoop:
Route: route,
}
- select {
- case e.c <- p:
- n++
- default:
- break packetLoop
+ if !e.q.Write(p) {
+ break
}
+ n++
}
return n, nil
@@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
GSO: nil,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
+
+// AddNotify adds a notification target for receiving event about outgoing
+// packets.
+func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
+ return e.q.AddNotify(notify)
+}
+
+// RemoveNotify removes handle from the list of notification targets.
+func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
+ e.q.RemoveNotify(handle)
+}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index e5096ea38..e0db6cf54 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -4,6 +4,22 @@ package(licenses = ["notice"])
go_library(
name = "tun",
- srcs = ["tun_unsafe.go"],
+ srcs = [
+ "device.go",
+ "protocol.go",
+ "tun_unsafe.go",
+ ],
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
)
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
new file mode 100644
index 000000000..6ff47a742
--- /dev/null
+++ b/pkg/tcpip/link/tun/device.go
@@ -0,0 +1,352 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // drivers/net/tun.c:tun_net_init()
+ defaultDevMtu = 1500
+
+ // Queue length for outbound packet, arriving at fd side for read. Overflow
+ // causes packet drops. gVisor implementation-specific.
+ defaultDevOutQueueLen = 1024
+)
+
+var zeroMAC [6]byte
+
+// Device is an opened /dev/net/tun device.
+//
+// +stateify savable
+type Device struct {
+ waiter.Queue
+
+ mu sync.RWMutex `state:"nosave"`
+ endpoint *tunEndpoint
+ notifyHandle *channel.NotificationHandle
+ flags uint16
+}
+
+// beforeSave is invoked by stateify.
+func (d *Device) beforeSave() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ // TODO(b/110961832): Restore the device to stack. At this moment, the stack
+ // is not savable.
+ if d.endpoint != nil {
+ panic("/dev/net/tun does not support save/restore when a device is associated with it.")
+ }
+}
+
+// Release implements fs.FileOperations.Release.
+func (d *Device) Release() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Decrease refcount if there is an endpoint associated with this file.
+ if d.endpoint != nil {
+ d.endpoint.RemoveNotify(d.notifyHandle)
+ d.endpoint.DecRef()
+ d.endpoint = nil
+ }
+}
+
+// SetIff services TUNSETIFF ioctl(2) request.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.endpoint != nil {
+ return syserror.EINVAL
+ }
+
+ // Input validations.
+ isTun := flags&linux.IFF_TUN != 0
+ isTap := flags&linux.IFF_TAP != 0
+ supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
+ if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
+ return syserror.EINVAL
+ }
+
+ prefix := "tun"
+ if isTap {
+ prefix = "tap"
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix)
+ if err != nil {
+ return syserror.EINVAL
+ }
+
+ d.endpoint = endpoint
+ d.notifyHandle = d.endpoint.AddNotify(d)
+ d.flags = flags
+ return nil
+}
+
+func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) {
+ for {
+ // 1. Try to attach to an existing NIC.
+ if name != "" {
+ if nic, found := s.GetNICByName(name); found {
+ endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if !ok {
+ // Not a NIC created by tun device.
+ return nil, syserror.EOPNOTSUPP
+ }
+ if !endpoint.TryIncRef() {
+ // Race detected: NIC got deleted in between.
+ continue
+ }
+ return endpoint, nil
+ }
+ }
+
+ // 2. Creating a new NIC.
+ id := tcpip.NICID(s.UniqueID())
+ endpoint := &tunEndpoint{
+ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
+ stack: s,
+ nicID: id,
+ name: name,
+ }
+ if endpoint.name == "" {
+ endpoint.name = fmt.Sprintf("%s%d", prefix, id)
+ }
+ err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
+ Name: endpoint.name,
+ })
+ switch err {
+ case nil:
+ return endpoint, nil
+ case tcpip.ErrDuplicateNICID:
+ // Race detected: A NIC has been created in between.
+ continue
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+}
+
+// Write inject one inbound packet to the network interface.
+func (d *Device) Write(data []byte) (int64, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return 0, syserror.EBADFD
+ }
+ if !endpoint.IsAttached() {
+ return 0, syserror.EIO
+ }
+
+ dataLen := int64(len(data))
+
+ // Packet information.
+ var pktInfoHdr PacketInfoHeader
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ if len(data) < PacketInfoHeaderSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
+ data = data[PacketInfoHeaderSize:]
+ }
+
+ // Ethernet header (TAP only).
+ var ethHdr header.Ethernet
+ if d.hasFlags(linux.IFF_TAP) {
+ if len(data) < header.EthernetMinimumSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
+ data = data[header.EthernetMinimumSize:]
+ }
+
+ // Try to determine network protocol number, default zero.
+ var protocol tcpip.NetworkProtocolNumber
+ switch {
+ case pktInfoHdr != nil:
+ protocol = pktInfoHdr.Protocol()
+ case ethHdr != nil:
+ protocol = ethHdr.Type()
+ }
+
+ // Try to determine remote link address, default zero.
+ var remote tcpip.LinkAddress
+ switch {
+ case ethHdr != nil:
+ remote = ethHdr.SourceAddress()
+ default:
+ remote = tcpip.LinkAddress(zeroMAC[:])
+ }
+
+ pkt := tcpip.PacketBuffer{
+ Data: buffer.View(data).ToVectorisedView(),
+ }
+ if ethHdr != nil {
+ pkt.LinkHeader = buffer.View(ethHdr)
+ }
+ endpoint.InjectLinkAddr(protocol, remote, pkt)
+ return dataLen, nil
+}
+
+// Read reads one outgoing packet from the network interface.
+func (d *Device) Read() ([]byte, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return nil, syserror.EBADFD
+ }
+
+ for {
+ info, ok := endpoint.Read()
+ if !ok {
+ return nil, syserror.ErrWouldBlock
+ }
+
+ v, ok := d.encodePkt(&info)
+ if !ok {
+ // Ignore unsupported packet.
+ continue
+ }
+ return v, nil
+ }
+}
+
+// encodePkt encodes packet for fd side.
+func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
+ var vv buffer.VectorisedView
+
+ // Packet information.
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
+ hdr.Encode(&PacketInfoFields{
+ Protocol: info.Proto,
+ })
+ vv.AppendView(buffer.View(hdr))
+ }
+
+ // If the packet does not already have link layer header, and the route
+ // does not exist, we can't compute it. This is possibly a raw packet, tun
+ // device doesn't support this at the moment.
+ if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" {
+ return nil, false
+ }
+
+ // Ethernet header (TAP only).
+ if d.hasFlags(linux.IFF_TAP) {
+ // Add ethernet header if not provided.
+ if info.Pkt.LinkHeader == nil {
+ hdr := &header.EthernetFields{
+ SrcAddr: info.Route.LocalLinkAddress,
+ DstAddr: info.Route.RemoteLinkAddress,
+ Type: info.Proto,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = d.endpoint.LinkAddress()
+ }
+
+ eth := make(header.Ethernet, header.EthernetMinimumSize)
+ eth.Encode(hdr)
+ vv.AppendView(buffer.View(eth))
+ } else {
+ vv.AppendView(info.Pkt.LinkHeader)
+ }
+ }
+
+ // Append upper headers.
+ vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):]))
+ // Append data payload.
+ vv.Append(info.Pkt.Data)
+
+ return vv.ToView(), true
+}
+
+// Name returns the name of the attached network interface. Empty string if
+// unattached.
+func (d *Device) Name() string {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ if d.endpoint != nil {
+ return d.endpoint.name
+ }
+ return ""
+}
+
+// Flags returns the flags set for d. Zero value if unset.
+func (d *Device) Flags() uint16 {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.flags
+}
+
+func (d *Device) hasFlags(flags uint16) bool {
+ return d.flags&flags == flags
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint != nil && endpoint.NumQueued() == 0 {
+ mask &= ^waiter.EventIn
+ }
+ }
+ return mask & (waiter.EventIn | waiter.EventOut)
+}
+
+// WriteNotify implements channel.Notification.WriteNotify.
+func (d *Device) WriteNotify() {
+ d.Notify(waiter.EventIn)
+}
+
+// tunEndpoint is the link endpoint for the NIC created by the tun device.
+//
+// It is ref-counted as multiple opening files can attach to the same NIC.
+// The last owner is responsible for deleting the NIC.
+type tunEndpoint struct {
+ *channel.Endpoint
+
+ refs.AtomicRefCount
+
+ stack *stack.Stack
+ nicID tcpip.NICID
+ name string
+}
+
+// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+func (e *tunEndpoint) DecRef() {
+ e.DecRefWithDestructor(func() {
+ e.stack.RemoveNIC(e.nicID)
+ })
+}
diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go
new file mode 100644
index 000000000..89d9d91a9
--- /dev/null
+++ b/pkg/tcpip/link/tun/protocol.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // PacketInfoHeaderSize is the size of the packet information header.
+ PacketInfoHeaderSize = 4
+
+ offsetFlags = 0
+ offsetProtocol = 2
+)
+
+// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is
+// not set.
+type PacketInfoFields struct {
+ Flags uint16
+ Protocol tcpip.NetworkProtocolNumber
+}
+
+// PacketInfoHeader is the wire representation of the packet information sent if
+// IFF_NO_PI flag is not set.
+type PacketInfoHeader []byte
+
+// Encode encodes f into h.
+func (h PacketInfoHeader) Encode(f *PacketInfoFields) {
+ binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags)
+ binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol))
+}
+
+// Flags returns the flag field in h.
+func (h PacketInfoHeader) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[offsetFlags:])
+}
+
+// Protocol returns the protocol field in h.
+func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:]))
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 4da13c5df..e9fcc89a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
}, nil
}
-// LinkAddressProtocol implements stack.LinkAddressResolver.
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-// LinkAddressRequest implements stack.LinkAddressResolver.
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
@@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
}
-// ResolveStaticAddress implements stack.LinkAddressResolver.
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return broadcastMAC, true
@@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
return tcpip.LinkAddress([]byte(nil)), false
}
-// SetOption implements NetworkProtocol.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements NetworkProtocol.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 6597e6781..4f1742938 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 180a480fd..9aef5234b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 045409bda..f651871ce 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// cleanupHostOnlyState cleans up any state that is only useful for hosts.
+// cleanupState cleans up ndp's state.
//
-// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
-// host to a router. This function will invalidate all discovered on-link
-// prefixes, discovered routers, and auto-generated addresses as routers do not
-// normally process Router Advertisements to discover default routers and
-// on-link prefixes, and auto-generate addresses via SLAAC.
+// If hostOnly is true, then only host-specific state will be cleaned up.
+//
+// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
+// transitioning from a host to a router. This function will invalidate all
+// discovered on-link prefixes, discovered routers, and auto-generated
+// addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address will not be
+// invalidated as routers are also expected to generate a link-local address.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupHostOnlyState() {
+func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
linkLocalAddrs := 0
for addr := range ndp.autoGenAddresses {
// RFC 4862 section 5 states that routers are also expected to generate a
- // link-local address so we do not invalidate them.
- if linkLocalSubnet.Contains(addr) {
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if hostOnly && linkLocalSubnet.Contains(addr) {
linkLocalAddrs++
continue
}
@@ -1230,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
pkt.SetType(header.ICMPv6RouterSolicit)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 1f6f77439..6e9306d09 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s
}
}
+// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
+// header length.
+type channelLinkWithHeaderLength struct {
+ *channel.Endpoint
+ headerLength uint16
+}
+
+func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
+ return l.headerLength
+}
+
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
@@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) {
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+// This tests also validates the NDP NS packet that is transmitted.
func TestDADResolve(t *testing.T) {
const nicID = 1
tests := []struct {
name string
+ linkHeaderLen uint16
dupAddrDetectTransmits uint8
retransTimer time.Duration
expectedRetransmitTimer time.Duration
}{
- {"1:1s:1s", 1, time.Second, time.Second},
- {"2:1s:1s", 2, time.Second, time.Second},
- {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ {
+ name: "1:1s:1s",
+ dupAddrDetectTransmits: 1,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "2:1s:1s",
+ linkHeaderLen: 1,
+ dupAddrDetectTransmits: 2,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "1:2s:2s",
+ linkHeaderLen: 2,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 2 * time.Second,
+ expectedRetransmitTimer: 2 * time.Second,
+ },
// 0s is an invalid RetransmitTimer timer and will be fixed to
// the default RetransmitTimer value of 1s.
- {"1:0s:1s", 1, 0, time.Second},
+ {
+ name: "1:0s:1s",
+ linkHeaderLen: 3,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 0,
+ expectedRetransmitTimer: time.Second,
+ },
}
for _, test := range tests {
@@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
+ if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) {
checker.NDPNSTargetAddress(addr1),
checker.NDPNSOptions(nil),
))
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
})
}
@@ -592,70 +635,94 @@ func TestDADFail(t *testing.T) {
}
}
-// TestDADStop tests to make sure that the DAD process stops when an address is
-// removed.
func TestDADStop(t *testing.T) {
const nicID = 1
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- ndpConfigs := stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
+ tests := []struct {
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests to make sure that DAD stops when an address is removed.
+ {
+ name: "Remove address",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
+ }
+ },
+ },
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ // Tests to make sure that DAD stops when the NIC is disabled.
+ {
+ name: "Disable NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Remove the address. This should stop DAD.
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
- }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
- // Wait for DAD to fail (since the address was removed during DAD).
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ test.stopFn(t, s)
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Errorf("got NeighborSolicit = %d, want <= 1", got)
+ }
+ })
}
}
@@ -2886,17 +2953,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
}
-// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers
-// and prefixes, and non-linklocal auto-generated addresses are invalidated when
-// a NIC becomes a router.
-func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
+// TestCleanupNDPState tests that all discovered routers and prefixes, and
+// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestCleanupNDPState(t *testing.T) {
t.Parallel()
const (
- lifetimeSeconds = 5
- maxEvents = 4
- nicID1 = 1
- nicID2 = 2
+ lifetimeSeconds = 5
+ maxRouterAndPrefixEvents = 4
+ nicID1 = 1
+ nicID2 = 2
)
prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2912,254 +2978,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
PrefixLen: 64,
}
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
+ tests := []struct {
+ name string
+ cleanupFn func(t *testing.T, s *stack.Stack)
+ keepAutoGenLinkLocal bool
+ maxAutoGenAddrEvents int
+ }{
+ // A NIC should still keep its auto-generated link-local address when
+ // becoming a router.
+ {
+ name: "Forwarding Enable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
+ keepAutoGenLinkLocal: true,
+ maxAutoGenAddrEvents: 4,
},
- NDPDisp: &ndpDisp,
- })
- expectRouterEvent := func() (bool, ndpRouterEvent) {
- select {
- case e := <-ndpDisp.routerC:
- return true, e
- default:
- }
+ // A NIC should cleanup all NDP state when it is disabled.
+ {
+ name: "NIC Disable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- return false, ndpRouterEvent{}
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ },
}
- expectPrefixEvent := func() (bool, ndpPrefixEvent) {
- select {
- case e := <-ndpDisp.prefixC:
- return true, e
- default:
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
- return false, ndpPrefixEvent{}
- }
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
- expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
- select {
- case e := <-ndpDisp.autoGenAddrC:
- return true, e
- default:
- }
+ return false, ndpRouterEvent{}
+ }
- return false, ndpAutoGenAddrEvent{}
- }
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
- e1 := channel.New(0, 1280, linkAddr1)
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
- // We have other tests that make sure we receive the *correct* events
- // on normal discovery of routers/prefixes, and auto-generated
- // addresses. Here we just make sure we get an event and let other tests
- // handle the correctness check.
- expectAutoGenAddrEvent()
+ return false, ndpPrefixEvent{}
+ }
- e2 := channel.New(0, 1280, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
- expectAutoGenAddrEvent()
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
- // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
- // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
- // llAddr4) to discover multiple routers and prefixes, and auto-gen
- // multiple addresses.
+ return false, ndpAutoGenAddrEvent{}
+ }
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
- }
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ expectAutoGenAddrEvent()
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
- }
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+ expectAutoGenAddrEvent()
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
- }
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
+ // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
+ // llAddr4) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
- // We should have the auto-generated addresses added.
- nicinfo := s.NICInfo()
- nic1Addrs := nicinfo[nicID1].ProtocolAddresses
- nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
- // We can't proceed any further if we already failed the test (missing
- // some discovery/auto-generated address events or addresses).
- if t.Failed() {
- t.FailNow()
- }
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
+ }
- s.SetForwarding(true)
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
- // Collect invalidation events after becoming a router
- gotRouterEvents := make(map[ndpRouterEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectRouterEvent()
- if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotRouterEvents[e]++
- }
- gotPrefixEvents := make(map[ndpPrefixEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectPrefixEvent()
- if !ok {
- t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotPrefixEvents[e]++
- }
- gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectAutoGenAddrEvent()
- if !ok {
- t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotAutoGenAddrEvents[e]++
- }
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
- // No need to proceed any further if we already failed the test (missing
- // some invalidation events).
- if t.Failed() {
- t.FailNow()
- }
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
- }
- expectedPrefixEvents := map[ndpPrefixEvent]int{
- {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
- t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
- }
- expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
- {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
- }
- if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
- t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
- }
+ test.cleanupFn(t, s)
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ // Collect invalidation events after having NDP state cleaned up.
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < test.maxAutoGenAddrEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
- // Should not get any more events (invalidation timers should have been
- // cancelled when we transitioned into a router).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
- select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
- default:
- }
- select {
- case <-ndpDisp.prefixC:
- t.Error("unexpected prefix event")
- default:
- }
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Error("unexpected auto-generated address event")
- default:
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+
+ if !test.keepAutoGenLinkLocal {
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
+ }
+
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ }
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when the NDP state was cleaned up).
+ time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+ })
}
}
@@ -3259,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
func TestRouterSolicitation(t *testing.T) {
t.Parallel()
+ const nicID = 1
+
tests := []struct {
name string
+ linkHeaderLen uint16
maxRtrSolicit uint8
rtrSolicitInt time.Duration
effectiveRtrSolicitInt time.Duration
@@ -3277,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS with delay",
+ linkHeaderLen: 1,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3285,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Single RS without delay",
+ linkHeaderLen: 2,
maxRtrSolicit: 1,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3293,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS without delay and invalid zero interval",
+ linkHeaderLen: 3,
maxRtrSolicit: 2,
rtrSolicitInt: 0,
effectiveRtrSolicitInt: 4 * time.Second,
@@ -3330,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
- e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
ctx, _ := context.WithTimeout(context.Background(), timeout)
@@ -3357,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPRS(),
)
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
@@ -3373,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) {
MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
// Make sure each RS got sent at the right
@@ -3406,77 +3539,130 @@ func TestRouterSolicitation(t *testing.T) {
})
}
-// TestStopStartSolicitingRouters tests that when forwarding is enabled or
-// disabled, router solicitations are stopped or started, respecitively.
func TestStopStartSolicitingRouters(t *testing.T) {
t.Parallel()
+ const nicID = 1
const interval = 500 * time.Millisecond
const delay = time.Second
const maxRtrSolicitations = 3
- e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
+ tests := []struct {
+ name string
+ startFn func(t *testing.T, s *stack.Stack)
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests that when forwarding is enabled or disabled, router solicitations
+ // are stopped or started, respectively.
+ {
+ name: "Forwarding enabled and disabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(false)
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
- // Enable forwarding which should stop router solicitations.
- s.SetForwarding(true)
- ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok = e.ReadContext(ctx); ok {
- t.Fatal("Should not have sent more than one RS message")
- }
- }
+ // Tests that when a NIC is enabled or disabled, router solicitations
+ // are started or stopped, respectively.
+ {
+ name: "NIC disabled and enabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Enabling forwarding again should do nothing.
- s.SetForwarding(true)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
- }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Disable forwarding which should start router solicitations.
- s.SetForwarding(false)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- // Disabling forwarding again should do nothing.
- s.SetForwarding(false)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Stop soliciting routers.
+ test.stopFn(t, s)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before forwarding was enabled.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("should not have sent more than one RS message")
+ }
+ }
+
+ // Stopping router solicitations after it has already been stopped should
+ // do nothing.
+ test.stopFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
+ }
+
+ // Start soliciting routers.
+ test.startFn(t, s)
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Starting router solicitations after it has already completed should do
+ // nothing.
+ test.startFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after finishing router solicitations")
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a75dc0322..271e55be0 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -28,6 +28,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/iptables"
)
+var ipv4BroadcastAddr = tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+}
+
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
@@ -133,11 +141,77 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
+ nic.linkEP.Attach(nic)
+
return nic
}
-// enable enables the NIC. enable will attach the link to its LinkEndpoint and
-// join the IPv6 All-Nodes Multicast address (ff02::1).
+// enabled returns true if n is enabled.
+func (n *NIC) enabled() bool {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ return enabled
+}
+
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if !enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if !n.mu.enabled {
+ return nil
+ }
+
+ // TODO(b/147015577): Should Routes that are currently bound to n be
+ // invalidated? Currently, Routes will continue to work when a NIC is enabled
+ // again, and applications may not know that the underlying NIC was ever
+ // disabled.
+
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ n.mu.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupState(false /* hostOnly */)
+
+ // Stop DAD for all the unicast IPv6 endpoints that are in the
+ // permanentTentative state.
+ for _, r := range n.mu.endpoints {
+ if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+ }
+
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The address may have already been removed.
+ if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ n.mu.enabled = false
+ return nil
+}
+
+// enable enables n.
+//
+// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast
+// address (ff02::1), start DAD for permanent addresses, and start soliciting
+// routers if the stack is not operating as a router. If the stack is also
+// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
n.mu.RLock()
enabled := n.mu.enabled
@@ -155,14 +229,9 @@ func (n *NIC) enable() *tcpip.Error {
n.mu.enabled = true
- n.attachLinkEndpoint()
-
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
@@ -184,6 +253,14 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
// state.
//
@@ -201,10 +278,6 @@ func (n *NIC) enable() *tcpip.Error {
}
}
- if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
- return err
- }
-
// Do not auto-generate an IPv6 link-local address for loopback devices.
if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
@@ -226,6 +299,33 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ var errs []*tcpip.Error
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+ if len(errs) > 0 {
+ return errs[0]
+ }
+
+ return nil
+}
+
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
@@ -235,7 +335,7 @@ func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
- n.mu.ndp.cleanupHostOnlyState()
+ n.mu.ndp.cleanupState(true /* hostOnly */)
n.mu.ndp.stopSolicitingRouters()
}
@@ -250,12 +350,6 @@ func (n *NIC) becomeIPv6Host() {
n.mu.ndp.startSolicitingRouters()
}
-// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
-// to start delivering packets.
-func (n *NIC) attachLinkEndpoint() {
- n.linkEP.Attach(n)
-}
-
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
@@ -713,6 +807,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
case permanentExpired, temporary:
continue
}
+
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -1010,6 +1105,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return nil
}
+// isInGroup returns true if n has joined the multicast group addr.
+func (n *NIC) isInGroup(addr tcpip.Address) bool {
+ n.mu.RLock()
+ joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
+ n.mu.RUnlock()
+
+ return joins != 0
+}
+
func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1237,6 +1341,11 @@ func (n *NIC) Stack() *Stack {
return n.stack
}
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// isAddrTentative returns true if addr is tentative on n.
//
// Note that if addr is not associated with n, then this function will return
@@ -1423,7 +1532,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
//
// r's NIC must be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.getKind() != permanentExpired || r.nic.mu.spoofing
+ return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index ec91f60dd..f9fd8f18f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -74,10 +74,11 @@ type TransportEndpoint interface {
// HandleControlPacket takes ownership of pkt.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
- // Close puts the endpoint in a closed state and frees all resources
- // associated with it. This cleanup may happen asynchronously. Wait can
- // be used to block on this asynchronous cleanup.
- Close()
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
// Wait waits for any worker goroutines owned by the endpoint to stop.
//
@@ -160,6 +161,13 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -277,7 +285,7 @@ type NetworkProtocol interface {
// DefaultPrefixLen returns the protocol's default prefix length.
DefaultPrefixLen() int
- // ParsePorts returns the source and destination addresses stored in a
+ // ParseAddresses returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
@@ -293,6 +301,13 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6eac16e16..ebb6c5e3b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -881,6 +881,8 @@ type NICOptions struct {
// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
// NICOptions. See the documentation on type NICOptions for details on how
// NICs can be configured.
+//
+// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -900,7 +902,6 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
-
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -910,34 +911,88 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
-// `LinkEndpoint.Attach` to start delivering packets to it.
+// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
+}
+
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
return nic.enable()
}
+// DisableNIC disables the given NIC.
+func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.disable()
+}
+
// CheckNIC checks if a NIC is usable.
func (s *Stack) CheckNIC(id tcpip.NICID) bool {
s.mu.RLock()
+ defer s.mu.RUnlock()
+
nic, ok := s.nics[id]
- s.mu.RUnlock()
- if ok {
- return nic.linkEP.IsAttached()
+ if !ok {
+ return false
}
- return false
+
+ return nic.enabled()
+}
+
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ if r.NIC != id {
+ // Keep this route.
+ if i > n {
+ s.routeTable[n] = r
+ }
+ n++
+ }
+ }
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
}
// NICAddressRanges returns a map of NICIDs to their associated subnets.
@@ -989,7 +1044,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.linkEP.IsAttached(),
+ Running: nic.enabled(),
Promiscuous: nic.isPromiscuousMode(),
Loopback: nic.isLoopback(),
}
@@ -1151,7 +1206,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok {
+ if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
@@ -1161,7 +1216,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok {
+ if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
@@ -1391,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
// Endpoints created or modified during this call may not get closed.
func (s *Stack) Close() {
for _, e := range s.RegisteredEndpoints() {
- e.Close()
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
}
}
@@ -1409,6 +1470,12 @@ func (s *Stack) Wait() {
for _, e := range s.CleanupEndpoints() {
e.Wait()
}
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1614,6 +1681,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
return tcpip.ErrUnknownNICID
}
+// IsInGroup returns true if the NIC with ID nicID has joined the multicast
+// group multicastAddr.
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.isInGroup(multicastAddr), nil
+ }
+ return false, tcpip.ErrUnknownNICID
+}
+
// IPTables returns the stack's iptables.
func (s *Stack) IPTables() iptables.IPTables {
s.tablesMu.RLock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7ba604442..edf6bec52 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -234,10 +235,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements TransportProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
+// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
+// that LinkEndpoint.Attach was called.
+type linkEPWithMockedAttach struct {
+ stack.LinkEndpoint
+ attached bool
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
+ l.LinkEndpoint.Attach(d)
+ l.attached = true
+}
+
+func (l *linkEPWithMockedAttach) isAttached() bool {
+ return l.attached
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -509,6 +533,296 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr
}
}
+// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
+// a NetworkDispatcher when the NIC is created.
+func TestAttachToLinkEndpointImmediately(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ nicOpts stack.NICOptions
+ }{
+ {
+ name: "Create enabled NIC",
+ nicOpts: stack.NICOptions{Disabled: false},
+ },
+ {
+ name: "Create disabled NIC",
+ nicOpts: stack.NICOptions{Disabled: true},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+
+ if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
+ }
+ if !e.isAttached() {
+ t.Fatalf("link endpoint not attached to a network disatcher")
+ }
+ })
+ }
+}
+
+func TestDisableUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := loopback.New()
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ checkNIC := func(enabled bool) {
+ t.Helper()
+
+ allNICInfo := s.NICInfo()
+ nicInfo, ok := allNICInfo[nicID]
+ if !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ } else if nicInfo.Flags.Running != enabled {
+ t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
+ }
+
+ if got := s.CheckNIC(nicID); got != enabled {
+ t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
+ }
+ }
+
+ // NIC should initially report itself as disabled.
+ checkNIC(false)
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(true)
+
+ // If the NIC is not reporting a correct enabled status, we cannot trust the
+ // next check so end the test here.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(false)
+}
+
+func TestRoutesWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Disabling NIC1 should result in no routes to odd addresses. Routes to even
+ // addresses should continue to be available as NIC2 is still enabled.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ nic1Dst := tcpip.Address("\x05")
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ nic2Dst := tcpip.Address("\x06")
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Disabling NIC2 should result in no routes to even addresses. No route
+ // should be available to any address as routes to odd addresses were made
+ // unavailable by disabling NIC1 above.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ // Enabling NIC1 should make routes to odd addresses available again. Routes
+ // to even addresses should continue to be unavailable as NIC2 is still
+ // disabled.
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+}
+
+func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ nic1Dst := tcpip.Address("\x05")
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
+
+ nic2Dst := tcpip.Address("\x06")
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
+
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC1 should fail.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC2 should fail.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ // Writes with Routes that use the re-enabled NIC1 should succeed.
+ // TODO(b/147015577): Should we instead completely invalidate all Routes that
+ // were bound to a disabled NIC at some point?
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+}
+
func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
@@ -2173,13 +2487,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
e := channel.New(0, 1280, test.linkAddr)
s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: test.nicName}
+ nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
}
- var expectedMainAddr tcpip.AddressWithPrefix
+ // A new disabled NIC should not have any address, even if auto generation
+ // was enabled.
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+ // Enabling the NIC should attempt auto-generation of a link-local
+ // address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+
+ var expectedMainAddr tcpip.AddressWithPrefix
if test.shouldGen {
expectedMainAddr = tcpip.AddressWithPrefix{
Address: test.expectedAddr,
@@ -2609,6 +2939,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
}
+func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should add the IPv4 broadcast address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 1 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
+ }
+ want := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
+ if allNICAddrs[0] != want {
+ t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+ }
+
+ // Disabling the NIC should remove the IPv4 broadcast address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+}
+
+func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
+ // not been enabled yet.
+ isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+}
+
// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
// was disabled have DAD performed on them when the NIC is enabled.
func TestDoDADWhenNICEnabled(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index d686e6eb8..778c0a4d6 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// Close implements stack.TransportEndpoint.Close.
-func (ep *multiPortEndpoint) Close() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Close()
- }
-}
-
-// Wait implements stack.TransportEndpoint.Wait.
-func (ep *multiPortEndpoint) Wait() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Wait()
- }
-}
-
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 869c69a6d..5d1da2f8b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
+}
+
func (f *fakeTransportEndpoint) Close() {
f.route.Release()
}
@@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 9ca39ce40..3dc5d87d6 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -323,11 +323,11 @@ type ControlMessages struct {
// TOS is the IPv4 type of service of the associated packet.
TOS uint8
- // HasTClass indicates whether Tclass is valid/set.
+ // HasTClass indicates whether TClass is valid/set.
HasTClass bool
- // Tclass is the IPv6 traffic class of the associated packet.
- TClass int32
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
// HasIPPacketInfo indicates whether PacketInfo is set.
HasIPPacketInfo bool
@@ -341,9 +341,15 @@ type ControlMessages struct {
// networking stack.
type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
- // associated with it.
+ // associated with it. Close initiates the teardown process, the
+ // Endpoint may not be fully closed when Close returns.
Close()
+ // Abort initiates an expedited endpoint teardown. As compared to
+ // Close, Abort prioritizes closing the Endpoint quickly over cleanly.
+ // Abort is best effort; implementing Abort with Close is acceptable.
+ Abort()
+
// Read reads data from the endpoint and optionally returns the sender.
//
// This method does not block if there is no data pending. It will also
@@ -502,9 +508,13 @@ type WriteOptions struct {
type SockOptBool int
const (
+ // ReceiveTClassOption is used by SetSockOpt/GetSockOpt to specify if the
+ // IPV6_TCLASS ancillary message is passed with incoming packets.
+ ReceiveTClassOption SockOptBool = iota
+
// ReceiveTOSOption is used by SetSockOpt/GetSockOpt to specify if the TOS
// ancillary message is passed with incoming packets.
- ReceiveTOSOption SockOptBool = iota
+ ReceiveTOSOption
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
@@ -514,6 +524,9 @@ const (
// if more inforamtion is provided with incoming packets such
// as interface index and address.
ReceiveIPPacketInfoOption
+
+ // TODO(b/146901447): convert existing bool socket options to be handled via
+ // Get/SetSockOptBool
)
// SockOptInt represents socket options which values have the int type.
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 48764b978..2f98a996f 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -25,6 +25,8 @@ import (
)
// StdClock implements Clock with the time package.
+//
+// +stateify savable
type StdClock struct{}
var _ Clock = (*StdClock)(nil)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 42afb3f5b..426da1ee6 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9ce500e80..113d92901 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index fc5bc69fa..5722815e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
return ep, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() {
ep.mu.Lock()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ee9c4c58b..2ef5fac76 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (e *endpoint) Close() {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 08afb7c17..13e383ffc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
@@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
s := sleep.Sleeper{}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5c5397823..7730e6445 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.snd.updateMaxPayloadSize(mtu, count)
}
- if n&notifyReset != 0 {
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
return tcpip.ErrConnectionAborted
}
@@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 {
+ if n&notifyClose != 0 || n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index e18012ac0..d792b07d6 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -68,17 +68,28 @@ func (q *epQueue) empty() bool {
type processor struct {
epQ epQueue
newEndpointWaker sleep.Waker
+ closeWaker sleep.Waker
id int
+ wg sync.WaitGroup
}
func newProcessor(id int) *processor {
p := &processor{
id: id,
}
+ p.wg.Add(1)
go p.handleSegments()
return p
}
+func (p *processor) close() {
+ p.closeWaker.Assert()
+}
+
+func (p *processor) wait() {
+ p.wg.Wait()
+}
+
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
@@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) {
func (p *processor) handleSegments() {
const newEndpointWaker = 1
+ const closeWaker = 2
s := sleep.Sleeper{}
s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ s.AddWaker(&p.closeWaker, closeWaker)
defer s.Done()
for {
- s.Fetch(true)
+ id, ok := s.Fetch(true)
+ if ok && id == closeWaker {
+ p.wg.Done()
+ return
+ }
for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
if ep.segmentQueue.empty() {
continue
@@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher {
}
}
+func (d *dispatcher) close() {
+ for _, p := range d.processors {
+ p.close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ for _, p := range d.processors {
+ p.wait()
+ }
+}
+
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f2be0e651..f1ad19dac 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -121,6 +121,8 @@ const (
notifyDrain
notifyReset
notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
}
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ // The abort notification is not processed synchronously, so no
+ // synchronization is needed.
+ //
+ // If the endpoint becomes connected after this check, we still close
+ // the endpoint. This worst case results in a slower abort.
+ //
+ // If the endpoint disconnected after the check, nothing needs to be
+ // done, so sending a notification which will potentially be ignored is
+ // fine.
+ if e.EndpointState().connected() {
+ e.notifyProtocolGoroutine(notifyAbort)
+ return
+ }
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources associated
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
@@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() {
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
- if !e.workerRunning {
+ switch e.EndpointState() {
+ // Sockets in StateSynRecv state(passive connections) are closed when
+ // the handshake fails or if the listening socket is closed while
+ // handshake was in progress. In such cases the handshake goroutine
+ // is already gone by the time Close is called and we need to cleanup
+ // here.
+ case StateInitial, StateBound, StateSynRecv:
e.cleanupLocked()
- } else {
+ e.setEndpointState(StateClose)
+ case StateError, StateClose:
+ // do nothing.
+ default:
e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 958c06fa7..73098d904 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -194,7 +194,7 @@ func replyWithReset(s *segment) {
sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
-// SetOption implements TransportProtocol.SetOption.
+// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case SACKEnabled:
@@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
}
-// Option implements TransportProtocol.Option.
+// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
@@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements stack.TransportProtocol.Close.
+func (p *protocol) Close() {
+ p.dispatcher.close()
+}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (p *protocol) Wait() {
+ p.dispatcher.wait()
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cc118c993..5b2b16afa 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -543,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
),
)
- // Wait for the TIME-WAIT state to transition to CLOSED.
- time.Sleep(1 * time.Second)
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 3fe91cac2..1c6a600b8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -32,7 +32,8 @@ type udpPacket struct {
packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- tos uint8
+ // tos stores either the receiveTOS or receiveTClass value.
+ tos uint8
}
// EndpointState represents the state of a UDP endpoint.
@@ -119,6 +120,10 @@ type endpoint struct {
// as ancillary data to ControlMessages on Read.
receiveTOS bool
+ // receiveTClass determines if the incoming IPv6 TClass header field is
+ // passed as ancillary data to ControlMessages on Read.
+ receiveTClass bool
+
// receiveIPPacketInfo determines if the packet info is returned by Read.
receiveIPPacketInfo bool
@@ -181,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -258,13 +268,18 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
}
e.mu.RLock()
receiveTOS := e.receiveTOS
+ receiveTClass := e.receiveTClass
receiveIPPacketInfo := e.receiveIPPacketInfo
e.mu.RUnlock()
if receiveTOS {
cm.HasTOS = true
cm.TOS = p.tos
}
-
+ if receiveTClass {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
if receiveIPPacketInfo {
cm.HasIPPacketInfo = true
cm.PacketInfo = p.packetInfo
@@ -490,6 +505,17 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
e.mu.Unlock()
return nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrNotSupported
+ }
+
+ e.mu.Lock()
+ e.receiveTClass = v
+ e.mu.Unlock()
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -709,6 +735,17 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+ case tcpip.ReceiveTClassOption:
+ // We only support this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrNotSupported
+ }
+
+ e.mu.RLock()
+ v := e.receiveTClass
+ e.mu.RUnlock()
+ return v, nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -1273,6 +1310,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
packet.packetInfo.LocalAddr = r.LocalAddress
packet.packetInfo.DestinationAddr = r.RemoteAddress
packet.packetInfo.NIC = r.NICID()
+ case header.IPv6ProtocolNumber:
+ packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
packet.timestamp = e.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 259c3072a..8df089d22 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index f0ff3fe71..34b7c2360 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -409,6 +409,7 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
+ TrafficClass: testTOS,
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
@@ -1336,7 +1337,7 @@ func TestSetTTL(t *testing.T) {
}
}
-func TestTOSV4(t *testing.T) {
+func TestSetTOS(t *testing.T) {
for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1347,23 +1348,23 @@ func TestTOSV4(t *testing.T) {
const tos = testTOS
var v tcpip.IPv4TOSOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
if err := c.ep.SetSockOpt(tcpip.IPv4TOSOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt(%#v) failed: %s", tcpip.IPv4TOSOption(tos), err)
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv4TOSOption(tos), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
if want := tcpip.IPv4TOSOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
testWrite(c, flow, checker.TOS(tos, 0))
@@ -1371,7 +1372,7 @@ func TestTOSV4(t *testing.T) {
}
}
-func TestTOSV6(t *testing.T) {
+func TestSetTClass(t *testing.T) {
for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
@@ -1379,71 +1380,92 @@ func TestTOSV6(t *testing.T) {
c.createEndpointForFlow(flow)
- const tos = testTOS
+ const tClass = testTOS
var v tcpip.IPv6TrafficClassOption
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, 0)
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, 0)
}
- if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tos)); err != nil {
- c.t.Errorf("SetSockOpt failed: %s", err)
+ if err := c.ep.SetSockOpt(tcpip.IPv6TrafficClassOption(tClass)); err != nil {
+ c.t.Errorf("SetSockOpt(%T, 0x%x) failed: %s", v, tcpip.IPv6TrafficClassOption(tClass), err)
}
if err := c.ep.GetSockOpt(&v); err != nil {
- c.t.Errorf("GetSockopt failed: %s", err)
+ c.t.Errorf("GetSockopt(%T) failed: %s", v, err)
}
- if want := tcpip.IPv6TrafficClassOption(tos); v != want {
- c.t.Errorf("got GetSockOpt(...) = %#v, want = %#v", v, want)
+ if want := tcpip.IPv6TrafficClassOption(tClass); v != want {
+ c.t.Errorf("got GetSockOpt(%T) = 0x%x, want = 0x%x", v, v, want)
}
- testWrite(c, flow, checker.TOS(tos, 0))
+ // The header getter for TClass is called TOS, so use that checker.
+ testWrite(c, flow, checker.TOS(tClass, 0))
})
}
}
-func TestReceiveTOSV4(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, broadcast} {
- t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
+func TestReceiveTosTClass(t *testing.T) {
+ testCases := []struct {
+ name string
+ getReceiveOption tcpip.SockOptBool
+ tests []testFlow
+ }{
+ {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
+ {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ }
+ for _, testCase := range testCases {
+ for _, flow := range testCase.tests {
+ t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
- c.createEndpointForFlow(flow)
+ c.createEndpointForFlow(flow)
+ option := testCase.getReceiveOption
+ name := testCase.name
- // Verify that setting and reading the option works.
- v, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- // Test for expected default value.
- if v != false {
- c.t.Errorf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", v, false)
- }
+ // Verify that setting and reading the option works.
+ v, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
+ // Test for expected default value.
+ if v != false {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
+ }
- want := true
- if err := c.ep.SetSockOptBool(tcpip.ReceiveTOSOption, want); err != nil {
- c.t.Fatalf("SetSockOptBool(tcpip.ReceiveTOSOption, %t) failed: %s", want, err)
- }
+ want := true
+ if err := c.ep.SetSockOptBool(option, want); err != nil {
+ c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
+ }
- got, err := c.ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- c.t.Fatal("GetSockOptBool(tcpip.ReceiveTOSOption) failed:", err)
- }
- if got != want {
- c.t.Fatalf("got GetSockOptBool(tcpip.ReceiveTOSOption) = %t, want = %t", got, want)
- }
+ got, err := c.ep.GetSockOptBool(option)
+ if err != nil {
+ c.t.Errorf("GetSockoptBool(%s) failed: %s", name, err)
+ }
- // Verify that the correct received TOS is handed through as
- // ancillary data to the ControlMessages struct.
- if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
- c.t.Fatal("Bind failed:", err)
- }
- testRead(c, flow, checker.ReceiveTOS(testTOS))
- })
+ if got != want {
+ c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
+ }
+
+ // Verify that the correct received TOS or TClass is handed through as
+ // ancillary data to the ControlMessages struct.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+ c.t.Fatalf("Bind failed: %s", err)
+ }
+ switch option {
+ case tcpip.ReceiveTClassOption:
+ testRead(c, flow, checker.ReceiveTClass(testTOS))
+ case tcpip.ReceiveTOSOption:
+ testRead(c, flow, checker.ReceiveTOS(testTOS))
+ default:
+ t.Fatalf("unknown test variant: %s", name)
+ }
+ })
+ }
}
}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index ff8b9e91a..6c9ada9c7 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -25,7 +25,6 @@ go_library(
"bytes_io_unsafe.go",
"usermem.go",
"usermem_arm64.go",
- "usermem_unsafe.go",
"usermem_x86.go",
],
visibility = ["//:sandbox"],
@@ -33,6 +32,7 @@ go_library(
"//pkg/atomicbitops",
"//pkg/binary",
"//pkg/context",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/safemem",
"//pkg/syserror",
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 71fd4e155..d2f4403b0 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -251,7 +252,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
}
end, ok := addr.AddLength(uint64(readlen))
if !ok {
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
@@ -272,16 +273,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
// Look for the terminating zero byte, which may have occurred before
// hitting err.
if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
+ return gohacks.StringFromImmutableBytes(buf[:done+i]), nil
}
done += n
if err != nil {
- return stringFromImmutableBytes(buf[:done]), err
+ return gohacks.StringFromImmutableBytes(buf[:done]), err
}
addr = end
}
- return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+ return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index ae4dd102a..26f68fe3d 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -19,7 +19,6 @@ go_library(
"loader_amd64.go",
"loader_arm64.go",
"network.go",
- "pprof.go",
"strace.go",
"user.go",
],
@@ -91,6 +90,7 @@ go_library(
"//pkg/usermem",
"//runsc/boot/filter",
"//runsc/boot/platforms",
+ "//runsc/boot/pprof",
"//runsc/specutils",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 9c9e94864..17e774e0c 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -142,7 +143,7 @@ func newController(fd int, l *Loader) (*controller, error) {
}
srv.Register(manager)
- if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok {
+ if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok {
net := &Network{
Stack: eps.Stack,
}
@@ -341,7 +342,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
return fmt.Errorf("creating memory file: %v", err)
}
k.SetMemoryFile(mf)
- networkStack := cm.l.k.NetworkStack()
+ networkStack := cm.l.k.RootNetworkNamespace().Stack()
cm.l.k = k
// Set up the restore environment.
@@ -365,9 +366,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
}
if cm.l.conf.ProfileEnable {
- // initializePProf opens /proc/self/maps, so has to be
- // called before installing seccomp filters.
- initializePProf()
+ // pprof.Initialize opens /proc/self/maps, so has to be called before
+ // installing seccomp filters.
+ pprof.Initialize()
}
// Seccomp filters have to be applied before parsing the state file.
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index c69f4c602..a4627905e 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -229,7 +229,9 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_NANOSLEEP: {},
syscall.SYS_PPOLL: {},
syscall.SYS_PREAD64: {},
+ syscall.SYS_PREADV: {},
syscall.SYS_PWRITE64: {},
+ syscall.SYS_PWRITEV: {},
syscall.SYS_READ: {},
syscall.SYS_RECVMSG: []seccomp.Rule{
{
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index eef43b9df..e7ca98134 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -49,6 +49,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -60,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/runsc/boot/filter"
_ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms.
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
// Include supported socket providers.
@@ -230,11 +232,8 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("enabling strace: %v", err)
}
- // Create an empty network stack because the network namespace may be empty at
- // this point. Netns is configured before Run() is called. Netstack is
- // configured using a control uRPC message. Host network is configured inside
- // Run().
- networkStack, err := newEmptyNetworkStack(args.Conf, k, k)
+ // Create root network namespace/stack.
+ netns, err := newRootNetworkNamespace(args.Conf, k, k)
if err != nil {
return nil, fmt.Errorf("creating network: %v", err)
}
@@ -277,7 +276,7 @@ func New(args Args) (*Loader, error) {
FeatureSet: cpuid.HostFeatureSet(),
Timekeeper: tk,
RootUserNamespace: creds.UserNamespace,
- NetworkStack: networkStack,
+ RootNetworkNamespace: netns,
ApplicationCores: uint(args.NumCPU),
Vdso: vdso,
RootUTSNamespace: kernel.NewUTSNamespace(args.Spec.Hostname, args.Spec.Hostname, creds.UserNamespace),
@@ -466,7 +465,7 @@ func (l *Loader) run() error {
// Delay host network configuration to this point because network namespace
// is configured after the loader is created and before Run() is called.
log.Debugf("Configuring host network")
- stack := l.k.NetworkStack().(*hostinet.Stack)
+ stack := l.k.RootNetworkNamespace().Stack().(*hostinet.Stack)
if err := stack.Configure(); err != nil {
return err
}
@@ -485,7 +484,7 @@ func (l *Loader) run() error {
// l.restore is set by the container manager when a restore call is made.
if !l.restore {
if l.conf.ProfileEnable {
- initializePProf()
+ pprof.Initialize()
}
// Finally done with all configuration. Setup filters before user code
@@ -908,48 +907,92 @@ func (l *Loader) WaitExit() kernel.ExitStatus {
return l.k.GlobalInit().ExitStatus()
}
-func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) {
+ // Create an empty network stack because the network namespace may be empty at
+ // this point. Netns is configured before Run() is called. Netstack is
+ // configured using a control uRPC message. Host network is configured inside
+ // Run().
switch conf.Network {
case NetworkHost:
- return hostinet.NewStack(), nil
+ // No network namespacing support for hostinet yet, hence creator is nil.
+ return inet.NewRootNamespace(hostinet.NewStack(), nil), nil
case NetworkNone, NetworkSandbox:
- // NetworkNone sets up loopback using netstack.
- netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
- transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
- s := netstack.Stack{stack.New(stack.Options{
- NetworkProtocols: netProtos,
- TransportProtocols: transProtos,
- Clock: clock,
- Stats: netstack.Metrics,
- HandleLocal: true,
- // Enable raw sockets for users with sufficient
- // privileges.
- RawFactory: raw.EndpointFactory{},
- UniqueID: uniqueID,
- })}
-
- // Enable SACK Recovery.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
- return nil, fmt.Errorf("failed to enable SACK: %v", err)
+ s, err := newEmptySandboxNetworkStack(clock, uniqueID)
+ if err != nil {
+ return nil, err
}
+ creator := &sandboxNetstackCreator{
+ clock: clock,
+ uniqueID: uniqueID,
+ }
+ return inet.NewRootNamespace(s, creator), nil
- // Set default TTLs as required by socket/netstack.
- s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
- s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ default:
+ panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ }
- // Enable Receive Buffer Auto-Tuning.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
- }
+}
- s.FillDefaultIPTables()
+func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := netstack.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: netstack.Metrics,
+ HandleLocal: true,
+ // Enable raw sockets for users with sufficient
+ // privileges.
+ RawFactory: raw.EndpointFactory{},
+ UniqueID: uniqueID,
+ })}
- return &s, nil
+ // Enable SACK Recovery.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
+ return nil, fmt.Errorf("failed to enable SACK: %v", err)
+ }
- default:
- panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ // Set default TTLs as required by socket/netstack.
+ s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+
+ // Enable Receive Buffer Auto-Tuning.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ s.FillDefaultIPTables()
+
+ return &s, nil
+}
+
+// sandboxNetstackCreator implements kernel.NetworkStackCreator.
+//
+// +stateify savable
+type sandboxNetstackCreator struct {
+ clock tcpip.Clock
+ uniqueID stack.UniqueID
+}
+
+// CreateStack implements kernel.NetworkStackCreator.CreateStack.
+func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) {
+ s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID)
+ if err != nil {
+ return nil, err
}
+
+ // Setup loopback.
+ n := &Network{Stack: s.(*netstack.Stack).Stack}
+ nicID := tcpip.NICID(f.uniqueID.UniqueID())
+ link := DefaultLoopbackLink
+ linkEP := loopback.New()
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
+ return nil, err
+ }
+
+ return s, nil
}
// signal sends a signal to one or more processes in a container. If PID is 0,
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 6a8765ec8..bee6ee336 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -17,6 +17,7 @@ package boot
import (
"fmt"
"net"
+ "strings"
"syscall"
"gvisor.dev/gvisor/pkg/log"
@@ -31,6 +32,32 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
+var (
+ // DefaultLoopbackLink contains IP addresses and routes of "127.0.0.1/8" and
+ // "::1/8" on "lo" interface.
+ DefaultLoopbackLink = LoopbackLink{
+ Name: "lo",
+ Addresses: []net.IP{
+ net.IP("\x7f\x00\x00\x01"),
+ net.IPv6loopback,
+ },
+ Routes: []Route{
+ {
+ Destination: net.IPNet{
+ IP: net.IPv4(0x7f, 0, 0, 0),
+ Mask: net.IPv4Mask(0xff, 0, 0, 0),
+ },
+ },
+ {
+ Destination: net.IPNet{
+ IP: net.IPv6loopback,
+ Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
+ },
+ },
+ },
+ }
+)
+
// Network exposes methods that can be used to configure a network stack.
type Network struct {
Stack *stack.Stack
diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD
new file mode 100644
index 000000000..29cb42b2f
--- /dev/null
+++ b/runsc/boot/pprof/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pprof",
+ srcs = ["pprof.go"],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+)
diff --git a/runsc/boot/pprof.go b/runsc/boot/pprof/pprof.go
index 463362f02..1ded20dee 100644
--- a/runsc/boot/pprof.go
+++ b/runsc/boot/pprof/pprof.go
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package boot
+// Package pprof provides a stub to initialize custom profilers.
+package pprof
-func initializePProf() {
+// Initialize will be called at boot for initializing custom profilers.
+func Initialize() {
}
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 2a88b85a9..d0bb4613a 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -31,6 +31,7 @@ go_library(
"spec.go",
"start.go",
"state.go",
+ "statefile.go",
"syscalls.go",
"wait.go",
],
@@ -43,6 +44,8 @@ go_library(
"//pkg/sentry/control",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/state",
+ "//pkg/state/statefile",
"//pkg/sync",
"//pkg/unet",
"//pkg/urpc",
diff --git a/runsc/cmd/statefile.go b/runsc/cmd/statefile.go
new file mode 100644
index 000000000..e6f1907da
--- /dev/null
+++ b/runsc/cmd/statefile.go
@@ -0,0 +1,143 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/google/subcommands"
+ "gvisor.dev/gvisor/pkg/state"
+ "gvisor.dev/gvisor/pkg/state/statefile"
+ "gvisor.dev/gvisor/runsc/flag"
+)
+
+// Statefile implements subcommands.Command for the "statefile" command.
+type Statefile struct {
+ list bool
+ get string
+ key string
+ output string
+ html bool
+}
+
+// Name implements subcommands.Command.
+func (*Statefile) Name() string {
+ return "state"
+}
+
+// Synopsis implements subcommands.Command.
+func (*Statefile) Synopsis() string {
+ return "shows information about a statefile"
+}
+
+// Usage implements subcommands.Command.
+func (*Statefile) Usage() string {
+ return `statefile [flags] <statefile>`
+}
+
+// SetFlags implements subcommands.Command.
+func (s *Statefile) SetFlags(f *flag.FlagSet) {
+ f.BoolVar(&s.list, "list", false, "lists the metdata in the statefile.")
+ f.StringVar(&s.get, "get", "", "extracts the given metadata key.")
+ f.StringVar(&s.key, "key", "", "the integrity key for the file.")
+ f.StringVar(&s.output, "output", "", "target to write the result.")
+ f.BoolVar(&s.html, "html", false, "outputs in HTML format.")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (s *Statefile) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ // Check arguments.
+ if s.list && s.get != "" {
+ Fatalf("error: can't specify -list and -get simultaneously.")
+ }
+
+ // Setup output.
+ var output = os.Stdout // Default.
+ if s.output != "" {
+ f, err := os.OpenFile(s.output, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0644)
+ if err != nil {
+ Fatalf("error opening output: %v", err)
+ }
+ defer func() {
+ if err := f.Close(); err != nil {
+ Fatalf("error flushing output: %v", err)
+ }
+ }()
+ output = f
+ }
+
+ // Open the file.
+ if f.NArg() != 1 {
+ f.Usage()
+ return subcommands.ExitUsageError
+ }
+ input, err := os.Open(f.Arg(0))
+ if err != nil {
+ Fatalf("error opening input: %v\n", err)
+ }
+
+ if s.html {
+ fmt.Fprintf(output, "<html><body>\n")
+ defer fmt.Fprintf(output, "</body></html>\n")
+ }
+
+ // Dump the full file?
+ if !s.list && s.get == "" {
+ var key []byte
+ if s.key != "" {
+ key = []byte(s.key)
+ }
+ rc, _, err := statefile.NewReader(input, key)
+ if err != nil {
+ Fatalf("error parsing statefile: %v", err)
+ }
+ if err := state.PrettyPrint(output, rc, s.html); err != nil {
+ Fatalf("error printing state: %v", err)
+ }
+ return subcommands.ExitSuccess
+ }
+
+ // Load just the metadata.
+ metadata, err := statefile.MetadataUnsafe(input)
+ if err != nil {
+ Fatalf("error reading metadata: %v", err)
+ }
+
+ // Is it a single key?
+ if s.get != "" {
+ val, ok := metadata[s.get]
+ if !ok {
+ Fatalf("metadata key %s: not found", s.get)
+ }
+ fmt.Fprintf(output, "%s\n", val)
+ return subcommands.ExitSuccess
+ }
+
+ // List all keys.
+ if s.html {
+ fmt.Fprintf(output, " <ul>\n")
+ defer fmt.Fprintf(output, " </ul>\n")
+ }
+ for key := range metadata {
+ if s.html {
+ fmt.Fprintf(output, " <li>%s</li>\n", key)
+ } else {
+ fmt.Fprintf(output, "%s\n", key)
+ }
+ }
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 04a7dc237..bdd65b498 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -71,6 +71,7 @@ func waitForProcessCount(cont *Container, want int) error {
return &backoff.PermanentError{Err: err}
}
if got := len(pss); got != want {
+ log.Infof("Waiting for process count to reach %d. Current: %d", want, got)
return fmt.Errorf("wrong process count, got: %d, want: %d", got, want)
}
return nil
diff --git a/runsc/main.go b/runsc/main.go
index 762b0f801..af73bed97 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -116,8 +116,8 @@ func main() {
subcommands.Register(new(cmd.Resume), "")
subcommands.Register(new(cmd.Run), "")
subcommands.Register(new(cmd.Spec), "")
- subcommands.Register(new(cmd.Start), "")
subcommands.Register(new(cmd.State), "")
+ subcommands.Register(new(cmd.Start), "")
subcommands.Register(new(cmd.Wait), "")
// Register internal commands with the internal group name. This causes
@@ -127,6 +127,7 @@ func main() {
subcommands.Register(new(cmd.Boot), internalGroup)
subcommands.Register(new(cmd.Debug), internalGroup)
subcommands.Register(new(cmd.Gofer), internalGroup)
+ subcommands.Register(new(cmd.Statefile), internalGroup)
// All subcommands must be registered before flag parsing.
flag.Parse()
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index 99e143696..bc093fba5 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -21,7 +21,6 @@ import (
"path/filepath"
"runtime"
"strconv"
- "strings"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -75,30 +74,8 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi
}
func createDefaultLoopbackInterface(conn *urpc.Client) error {
- link := boot.LoopbackLink{
- Name: "lo",
- Addresses: []net.IP{
- net.IP("\x7f\x00\x00\x01"),
- net.IPv6loopback,
- },
- Routes: []boot.Route{
- {
- Destination: net.IPNet{
-
- IP: net.IPv4(0x7f, 0, 0, 0),
- Mask: net.IPv4Mask(0xff, 0, 0, 0),
- },
- },
- {
- Destination: net.IPNet{
- IP: net.IPv6loopback,
- Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
- },
- },
- },
- }
if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &boot.CreateLinksAndRoutesArgs{
- LoopbackLinks: []boot.LoopbackLink{link},
+ LoopbackLinks: []boot.LoopbackLink{boot.DefaultLoopbackLink},
}, nil); err != nil {
return fmt.Errorf("creating loopback link and routes: %v", err)
}
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
new file mode 100644
index 000000000..a0317db02
--- /dev/null
+++ b/scripts/benchmark.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2020 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run in the root of the repo.
+cd "$(dirname "$0")"
+
+KEY_PATH=${KEY_PATH:-"${KOKORO_KEYSTORE_DIR}/${KOKORO_SERVICE_ACCOUNT}"}
+
+gcloud auth activate-service-account --key-file "${KEY_PATH}"
+
+gcloud compute instances list
+
diff --git a/scripts/common_build.sh b/scripts/common_build.sh
index ae8b67383..3be0bb21c 100755
--- a/scripts/common_build.sh
+++ b/scripts/common_build.sh
@@ -70,7 +70,9 @@ function collect_logs() {
for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do
junitparser merge `find $d -name test.xml` $d/test.xml
cat $d/shard_*_of_*/test.log > $d/test.log
- ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip
+ if ls -l $d/shard_*_of_*/test.outputs/outputs.zip 2>/dev/null; then
+ zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs/outputs.zip
+ fi
done
find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf
# Move test logs to Kokoro directory. tar is used to conveniently perform
@@ -90,7 +92,13 @@ function collect_logs() {
echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
fi
- tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" .
+ time tar \
+ --verbose \
+ --create \
+ --gzip \
+ --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \
+ --directory "${RUNSC_LOGS_DIR}" \
+ .
fi
fi
fi
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index e26d6a7d2..b2fb6401a 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -26,6 +26,7 @@ const (
acceptPort = 2402
sendloopDuration = 2 * time.Second
network = "udp4"
+ chainName = "foochain"
)
func init() {
@@ -40,6 +41,12 @@ func init() {
RegisterTestCase(FilterInputDefaultPolicyAccept{})
RegisterTestCase(FilterInputDefaultPolicyDrop{})
RegisterTestCase(FilterInputReturnUnderflow{})
+ RegisterTestCase(FilterInputSerializeJump{})
+ RegisterTestCase(FilterInputJumpBasic{})
+ RegisterTestCase(FilterInputJumpReturn{})
+ RegisterTestCase(FilterInputJumpReturnDrop{})
+ RegisterTestCase(FilterInputJumpBuiltin{})
+ RegisterTestCase(FilterInputJumpTwice{})
}
// FilterInputDropUDP tests that we can drop UDP traffic.
@@ -267,13 +274,12 @@ func (FilterInputMultiUDPRules) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterInputMultiUDPRules) ContainerAction(ip net.IP) error {
- if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil {
- return err
- }
- if err := filterTable("-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"); err != nil {
- return err
+ rules := [][]string{
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"},
+ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"},
+ {"-L"},
}
- return filterTable("-L")
+ return filterTableRules(rules)
}
// LocalAction implements TestCase.LocalAction.
@@ -314,14 +320,13 @@ func (FilterInputCreateUserChain) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (FilterInputCreateUserChain) ContainerAction(ip net.IP) error {
- // Create a chain.
- const chainName = "foochain"
- if err := filterTable("-N", chainName); err != nil {
- return err
+ rules := [][]string{
+ // Create a chain.
+ {"-N", chainName},
+ // Add a simple rule to the chain.
+ {"-A", chainName, "-j", "DROP"},
}
-
- // Add a simple rule to the chain.
- return filterTable("-A", chainName, "-j", "DROP")
+ return filterTableRules(rules)
}
// LocalAction implements TestCase.LocalAction.
@@ -396,13 +401,12 @@ func (FilterInputReturnUnderflow) Name() string {
func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error {
// Add a RETURN rule followed by an unconditional accept, and set the
// default policy to DROP.
- if err := filterTable("-A", "INPUT", "-j", "RETURN"); err != nil {
- return err
+ rules := [][]string{
+ {"-A", "INPUT", "-j", "RETURN"},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-P", "INPUT", "ACCEPT"},
}
- if err := filterTable("-A", "INPUT", "-j", "DROP"); err != nil {
- return err
- }
- if err := filterTable("-P", "INPUT", "ACCEPT"); err != nil {
+ if err := filterTableRules(rules); err != nil {
return err
}
@@ -415,3 +419,179 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error {
func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error {
return sendUDPLoop(ip, acceptPort, sendloopDuration)
}
+
+// FilterInputSerializeJump verifies that we can serialize jumps.
+type FilterInputSerializeJump struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputSerializeJump) Name() string {
+ return "FilterInputSerializeJump"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputSerializeJump) ContainerAction(ip net.IP) error {
+ // Write a JUMP rule, the serialize it with `-L`.
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-L"},
+ }
+ return filterTableRules(rules)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputSerializeJump) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpBasic jumps to a chain and executes a rule there.
+type FilterInputJumpBasic struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBasic) Name() string {
+ return "FilterInputJumpBasic"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBasic) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBasic) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputJumpReturn jumps, returns, and executes a rule.
+type FilterInputJumpReturn struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturn) Name() string {
+ return "FilterInputJumpReturn"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturn) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-P", "INPUT", "ACCEPT"},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", "RETURN"},
+ {"-A", chainName, "-j", "DROP"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on acceptPort.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturn) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
+
+// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets.
+type FilterInputJumpReturnDrop struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpReturnDrop) Name() string {
+ return "FilterInputJumpReturnDrop"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error {
+ rules := [][]string{
+ {"-N", chainName},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", "INPUT", "-j", "DROP"},
+ {"-A", chainName, "-j", "RETURN"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // Listen for UDP packets on dropPort.
+ if err := listenUDP(dropPort, sendloopDuration); err == nil {
+ return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort)
+ } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
+ return fmt.Errorf("error reading: %v", err)
+ }
+
+ // At this point we know that reading timed out and never received a
+ // packet.
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, dropPort, sendloopDuration)
+}
+
+// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal.
+type FilterInputJumpBuiltin struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpBuiltin) Name() string {
+ return "FilterInputJumpBuiltin"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpBuiltin) ContainerAction(ip net.IP) error {
+ if err := filterTable("-A", "INPUT", "-j", "OUTPUT"); err == nil {
+ return fmt.Errorf("iptables should be unable to jump to a built-in chain")
+ }
+ return nil
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpBuiltin) LocalAction(ip net.IP) error {
+ // No-op.
+ return nil
+}
+
+// FilterInputJumpTwice jumps twice, then returns twice and executes a rule.
+type FilterInputJumpTwice struct{}
+
+// Name implements TestCase.Name.
+func (FilterInputJumpTwice) Name() string {
+ return "FilterInputJumpTwice"
+}
+
+// ContainerAction implements TestCase.ContainerAction.
+func (FilterInputJumpTwice) ContainerAction(ip net.IP) error {
+ const chainName2 = chainName + "2"
+ rules := [][]string{
+ {"-P", "INPUT", "DROP"},
+ {"-N", chainName},
+ {"-N", chainName2},
+ {"-A", "INPUT", "-j", chainName},
+ {"-A", chainName, "-j", chainName2},
+ {"-A", "INPUT", "-j", "ACCEPT"},
+ }
+ if err := filterTableRules(rules); err != nil {
+ return err
+ }
+
+ // UDP packets should jump and return twice, eventually hitting the
+ // ACCEPT rule.
+ return listenUDP(acceptPort, sendloopDuration)
+}
+
+// LocalAction implements TestCase.LocalAction.
+func (FilterInputJumpTwice) LocalAction(ip net.IP) error {
+ return sendUDPLoop(ip, acceptPort, sendloopDuration)
+}
diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go
index 7d061acba..29ad5932d 100644
--- a/test/iptables/iptables_test.go
+++ b/test/iptables/iptables_test.go
@@ -261,3 +261,39 @@ func TestFilterOutputDropTCPSrcPort(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestJumpSerialize(t *testing.T) {
+ if err := singleTest(FilterInputSerializeJump{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpBasic(t *testing.T) {
+ if err := singleTest(FilterInputJumpBasic{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpReturn(t *testing.T) {
+ if err := singleTest(FilterInputJumpReturn{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpReturnDrop(t *testing.T) {
+ if err := singleTest(FilterInputJumpReturnDrop{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpBuiltin(t *testing.T) {
+ if err := singleTest(FilterInputJumpBuiltin{}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestJumpTwice(t *testing.T) {
+ if err := singleTest(FilterInputJumpTwice{}); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index 5c9199abf..32cf5a417 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -27,17 +27,16 @@ const iptablesBinary = "iptables"
// filterTable calls `iptables -t filter` with the given args.
func filterTable(args ...string) error {
- args = append([]string{"-t", "filter"}, args...)
- cmd := exec.Command(iptablesBinary, args...)
- if out, err := cmd.CombinedOutput(); err != nil {
- return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out))
- }
- return nil
+ return tableCmd("filter", args)
}
// natTable calls `iptables -t nat` with the given args.
func natTable(args ...string) error {
- args = append([]string{"-t", "nat"}, args...)
+ return tableCmd("nat", args)
+}
+
+func tableCmd(table string, args []string) error {
+ args = append([]string{"-t", table}, args...)
cmd := exec.Command(iptablesBinary, args...)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("error running iptables with args %v\nerror: %v\noutput: %s", args, err, string(out))
@@ -45,6 +44,16 @@ func natTable(args ...string) error {
return nil
}
+// filterTableRules is like filterTable, but runs multiple iptables commands.
+func filterTableRules(argsList [][]string) error {
+ for _, args := range argsList {
+ if err := filterTable(args...); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
// the first read on that port.
func listenUDP(port int, timeout time.Duration) error {
diff --git a/test/perf/BUILD b/test/perf/BUILD
new file mode 100644
index 000000000..346a28e16
--- /dev/null
+++ b/test/perf/BUILD
@@ -0,0 +1,114 @@
+load("//test/runner:defs.bzl", "syscall_test")
+
+package(licenses = ["notice"])
+
+syscall_test(
+ test = "//test/perf/linux:clock_getres_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:clock_gettime_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:death_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:epoll_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:fork_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:futex_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:getdents_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:getpid_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:gettid_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:mapping_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:open_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:pipe_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:randread_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:read_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:sched_yield_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ test = "//test/perf/linux:send_recv_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:seqwrite_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ test = "//test/perf/linux:signal_benchmark",
+)
+
+syscall_test(
+ test = "//test/perf/linux:sleep_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:stat_benchmark",
+)
+
+syscall_test(
+ size = "enormous",
+ add_overlay = True,
+ test = "//test/perf/linux:unlink_benchmark",
+)
+
+syscall_test(
+ size = "large",
+ add_overlay = True,
+ test = "//test/perf/linux:write_benchmark",
+)
diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD
new file mode 100644
index 000000000..b4e907826
--- /dev/null
+++ b/test/perf/linux/BUILD
@@ -0,0 +1,356 @@
+load("//tools:defs.bzl", "cc_binary", "gbenchmark", "gtest")
+
+package(
+ default_visibility = ["//:sandbox"],
+ licenses = ["notice"],
+)
+
+cc_binary(
+ name = "getpid_benchmark",
+ testonly = 1,
+ srcs = [
+ "getpid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "send_recv_benchmark",
+ testonly = 1,
+ srcs = [
+ "send_recv_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/syscalls/linux:socket_test_util",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "gettid_benchmark",
+ testonly = 1,
+ srcs = [
+ "gettid_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "sched_yield_benchmark",
+ testonly = 1,
+ srcs = [
+ "sched_yield_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "clock_getres_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_getres_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "clock_gettime_benchmark",
+ testonly = 1,
+ srcs = [
+ "clock_gettime_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:test_main",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "open_benchmark",
+ testonly = 1,
+ srcs = [
+ "open_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "read_benchmark",
+ testonly = 1,
+ srcs = [
+ "read_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "randread_benchmark",
+ testonly = 1,
+ srcs = [
+ "randread_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "write_benchmark",
+ testonly = 1,
+ srcs = [
+ "write_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "seqwrite_benchmark",
+ testonly = 1,
+ srcs = [
+ "seqwrite_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/random",
+ ],
+)
+
+cc_binary(
+ name = "pipe_benchmark",
+ testonly = 1,
+ srcs = [
+ "pipe_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ ],
+)
+
+cc_binary(
+ name = "fork_benchmark",
+ testonly = 1,
+ srcs = [
+ "fork_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:cleanup",
+ "//test/util:file_descriptor",
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
+ name = "futex_benchmark",
+ testonly = 1,
+ srcs = [
+ "futex_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "epoll_benchmark",
+ testonly = 1,
+ srcs = [
+ "epoll_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:epoll_util",
+ "//test/util:file_descriptor",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_binary(
+ name = "death_benchmark",
+ testonly = 1,
+ srcs = [
+ "death_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "mapping_benchmark",
+ testonly = 1,
+ srcs = [
+ "mapping_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:memory_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "signal_benchmark",
+ testonly = 1,
+ srcs = [
+ "signal_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "getdents_benchmark",
+ testonly = 1,
+ srcs = [
+ "getdents_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
+ name = "sleep_benchmark",
+ testonly = 1,
+ srcs = [
+ "sleep_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:logging",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
+ name = "stat_benchmark",
+ testonly = 1,
+ srcs = [
+ "stat_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_binary(
+ name = "unlink_benchmark",
+ testonly = 1,
+ srcs = [
+ "unlink_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
diff --git a/test/perf/linux/clock_getres_benchmark.cc b/test/perf/linux/clock_getres_benchmark.cc
new file mode 100644
index 000000000..b051293ad
--- /dev/null
+++ b/test/perf/linux/clock_getres_benchmark.cc
@@ -0,0 +1,39 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// clock_getres(1) is very nearly a no-op syscall, but it does require copying
+// out to a userspace struct. It thus provides a nice small copy-out benchmark.
+void BM_ClockGetRes(benchmark::State& state) {
+ struct timespec ts;
+ for (auto _ : state) {
+ clock_getres(CLOCK_MONOTONIC, &ts);
+ }
+}
+
+BENCHMARK(BM_ClockGetRes);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/clock_gettime_benchmark.cc b/test/perf/linux/clock_gettime_benchmark.cc
new file mode 100644
index 000000000..6691bebd9
--- /dev/null
+++ b/test/perf/linux/clock_gettime_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <pthread.h>
+#include <time.h>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_ClockGettimeThreadCPUTime(benchmark::State& state) {
+ clockid_t clockid;
+ ASSERT_EQ(0, pthread_getcpuclockid(pthread_self(), &clockid));
+ struct timespec tp;
+
+ for (auto _ : state) {
+ clock_gettime(clockid, &tp);
+ }
+}
+
+BENCHMARK(BM_ClockGettimeThreadCPUTime);
+
+void BM_VDSOClockGettime(benchmark::State& state) {
+ const clockid_t clock = state.range(0);
+ struct timespec tp;
+ absl::Time start = absl::Now();
+
+ // Don't benchmark the calibration phase.
+ while (absl::Now() < start + absl::Milliseconds(2100)) {
+ clock_gettime(clock, &tp);
+ }
+
+ for (auto _ : state) {
+ clock_gettime(clock, &tp);
+ }
+}
+
+BENCHMARK(BM_VDSOClockGettime)->Arg(CLOCK_MONOTONIC)->Arg(CLOCK_REALTIME);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/death_benchmark.cc b/test/perf/linux/death_benchmark.cc
new file mode 100644
index 000000000..cb2b6fd07
--- /dev/null
+++ b/test/perf/linux/death_benchmark.cc
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// DeathTest is not so much a microbenchmark as a macrobenchmark. It is testing
+// the ability of gVisor (on whatever platform) to execute all the related
+// stack-dumping routines associated with EXPECT_EXIT / EXPECT_DEATH.
+TEST(DeathTest, ZeroEqualsOne) {
+ EXPECT_EXIT({ TEST_CHECK(0 == 1); }, ::testing::KilledBySignal(SIGABRT), "");
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/epoll_benchmark.cc b/test/perf/linux/epoll_benchmark.cc
new file mode 100644
index 000000000..0b121338a
--- /dev/null
+++ b/test/perf/linux/epoll_benchmark.cc
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/epoll.h>
+#include <sys/eventfd.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+#include <memory>
+
+#include "gtest/gtest.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/epoll_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Returns a new eventfd.
+PosixErrorOr<FileDescriptor> NewEventFD() {
+ int fd = eventfd(0, /* flags = */ 0);
+ MaybeSave();
+ if (fd < 0) {
+ return PosixError(errno, "eventfd");
+ }
+ return FileDescriptor(fd);
+}
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollTimeout(benchmark::State& state) {
+ constexpr int kFDsPerEpoll = 3;
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < kFDsPerEpoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+ }
+
+ struct epoll_event result[kFDsPerEpoll];
+ int timeout_ms = state.range(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(0, epoll_wait(epollfd.get(), result, kFDsPerEpoll, timeout_ms));
+ }
+}
+
+BENCHMARK(BM_EpollTimeout)->Range(0, 8);
+
+// Also stolen from epoll.cc unit tests.
+void BM_EpollAllEvents(benchmark::State& state) {
+ auto epollfd = ASSERT_NO_ERRNO_AND_VALUE(NewEpollFD());
+ const int fds_per_epoll = state.range(0);
+ constexpr uint64_t kEventVal = 5;
+
+ std::vector<FileDescriptor> eventfds;
+ for (int i = 0; i < fds_per_epoll; i++) {
+ eventfds.push_back(ASSERT_NO_ERRNO_AND_VALUE(NewEventFD()));
+ ASSERT_NO_ERRNO(
+ RegisterEpollFD(epollfd.get(), eventfds[i].get(), EPOLLIN, 0));
+
+ ASSERT_THAT(WriteFd(eventfds[i].get(), &kEventVal, sizeof(kEventVal)),
+ SyscallSucceedsWithValue(sizeof(kEventVal)));
+ }
+
+ std::vector<struct epoll_event> result(fds_per_epoll);
+
+ for (auto _ : state) {
+ EXPECT_EQ(fds_per_epoll,
+ epoll_wait(epollfd.get(), result.data(), fds_per_epoll, 0));
+ }
+}
+
+BENCHMARK(BM_EpollAllEvents)->Range(2, 1024);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/fork_benchmark.cc b/test/perf/linux/fork_benchmark.cc
new file mode 100644
index 000000000..84fdbc8a0
--- /dev/null
+++ b/test/perf/linux/fork_benchmark.cc
@@ -0,0 +1,350 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/barrier.h"
+#include "benchmark/benchmark.h"
+#include "test/util/cleanup.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kBusyMax = 250;
+
+// Do some CPU-bound busy-work.
+int busy(int max) {
+ // Prevent the compiler from optimizing this work away,
+ volatile int count = 0;
+
+ for (int i = 1; i < max; i++) {
+ for (int j = 2; j < i / 2; j++) {
+ if (i % j == 0) {
+ count++;
+ }
+ }
+ }
+
+ return count;
+}
+
+void BM_CPUBoundUniprocess(benchmark::State& state) {
+ for (auto _ : state) {
+ busy(kBusyMax);
+ }
+}
+
+BENCHMARK(BM_CPUBoundUniprocess);
+
+void BM_CPUBoundAsymmetric(benchmark::State& state) {
+ const size_t max = state.max_iterations;
+ pid_t child = fork();
+ if (child == 0) {
+ for (int i = 0; i < max; i++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ ASSERT_TRUE(state.KeepRunningBatch(max));
+
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ ASSERT_FALSE(state.KeepRunning());
+}
+
+BENCHMARK(BM_CPUBoundAsymmetric)->UseRealTime();
+
+void BM_CPUBoundSymmetric(benchmark::State& state) {
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ ASSERT_FALSE(state.KeepRunning());
+ });
+
+ const int processes = state.range(0);
+ for (int i = 0; i < processes; i++) {
+ size_t cur = (state.max_iterations + (processes - 1)) / processes;
+ if ((state.iterations() + cur) >= state.max_iterations) {
+ cur = state.max_iterations - state.iterations();
+ }
+ pid_t child = fork();
+ if (child == 0) {
+ for (int i = 0; i < cur; i++) {
+ busy(kBusyMax);
+ }
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ if (cur > 0) {
+ // We can have a zero cur here, depending.
+ ASSERT_TRUE(state.KeepRunningBatch(cur));
+ }
+ children.push_back(child);
+ }
+}
+
+BENCHMARK(BM_CPUBoundSymmetric)->Range(2, 16)->UseRealTime();
+
+// Child routine for ProcessSwitch/ThreadSwitch.
+// Reads from readfd and writes the result to writefd.
+void SwitchChild(int readfd, int writefd) {
+ while (1) {
+ char buf;
+ int ret = ReadFd(readfd, &buf, 1);
+ if (ret == 0) {
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "read failed");
+
+ ret = WriteFd(writefd, &buf, 1);
+ if (ret == -1) {
+ TEST_CHECK_MSG(errno == EPIPE, "unexpected write failure");
+ break;
+ }
+ TEST_CHECK_MSG(ret == 1, "write failed");
+ }
+}
+
+// Send bytes in a loop through a series of pipes, each passing through a
+// different process.
+//
+// Proc 0 Proc 1
+// * ----------> *
+// ^ Pipe 1 |
+// | |
+// | Pipe 0 | Pipe 2
+// | |
+// | |
+// | Pipe 3 v
+// * <---------- *
+// Proc 3 Proc 2
+//
+// This exercises context switching through multiple processes.
+void BM_ProcessSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two processes.
+ const int num_processes = state.range(0);
+ ASSERT_GE(num_processes, 2);
+
+ std::vector<pid_t> children;
+ auto child_cleanup = Cleanup([&] {
+ for (const pid_t child : children) {
+ int status;
+ EXPECT_THAT(RetryEINTR(waitpid)(child, &status, 0), SyscallSucceeds());
+ EXPECT_TRUE(WIFEXITED(status));
+ EXPECT_EQ(0, WEXITSTATUS(status));
+ }
+ });
+
+ // Must come after children, as the FDs must be closed before the children
+ // will exit.
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_processes; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This process is one of the processes in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_processes; i++) {
+ // Read from current pipe index, write to next.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = (i + 1) % num_processes;
+ const int write_fd = write_fds[write_index].get();
+
+ // std::vector isn't safe to use from the fork child.
+ FileDescriptor* read_array = read_fds.data();
+ FileDescriptor* write_array = write_fds.data();
+
+ pid_t child = fork();
+ if (!child) {
+ // Close all other FDs.
+ for (int j = 0; j < num_processes; j++) {
+ if (j != read_index) {
+ read_array[j].reset();
+ }
+ if (j != write_index) {
+ write_array[j].reset();
+ }
+ }
+
+ SwitchChild(read_fd, write_fd);
+ _exit(0);
+ }
+ ASSERT_THAT(child, SyscallSucceeds());
+ children.push_back(child);
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+}
+
+BENCHMARK(BM_ProcessSwitch)->Range(2, 16)->UseRealTime();
+
+// Equivalent to BM_ThreadSwitch using threads instead of processes.
+void BM_ThreadSwitch(benchmark::State& state) {
+ // Code below assumes there are at least two threads.
+ const int num_threads = state.range(0);
+ ASSERT_GE(num_threads, 2);
+
+ // Must come after threads, as the FDs must be closed before the children
+ // will exit.
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+ std::vector<FileDescriptor> read_fds;
+ std::vector<FileDescriptor> write_fds;
+
+ for (int i = 0; i < num_threads; i++) {
+ int fds[2];
+ ASSERT_THAT(pipe(fds), SyscallSucceeds());
+ read_fds.emplace_back(fds[0]);
+ write_fds.emplace_back(fds[1]);
+ }
+
+ // This thread is one of the threads in the loop. It will be considered
+ // index 0.
+ for (int i = 1; i < num_threads; i++) {
+ // Read from current pipe index, write to next.
+ //
+ // Transfer ownership of the FDs to the thread.
+ const int read_index = i;
+ const int read_fd = read_fds[read_index].release();
+
+ const int write_index = (i + 1) % num_threads;
+ const int write_fd = write_fds[write_index].release();
+
+ threads.emplace_back(std::make_unique<ScopedThread>([read_fd, write_fd] {
+ FileDescriptor read(read_fd);
+ FileDescriptor write(write_fd);
+ SwitchChild(read.get(), write.get());
+ }));
+ }
+
+ // Read from current pipe index (0), write to next (1).
+ const int read_index = 0;
+ const int read_fd = read_fds[read_index].get();
+
+ const int write_index = 1;
+ const int write_fd = write_fds[write_index].get();
+
+ // Kick start the loop.
+ char buf = 'a';
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+
+ for (auto _ : state) {
+ ASSERT_THAT(ReadFd(read_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ ASSERT_THAT(WriteFd(write_fd, &buf, 1), SyscallSucceedsWithValue(1));
+ }
+
+ // The two FDs still owned by this thread are closed, causing the next thread
+ // to exit its loop and close its FDs, and so on until all threads exit.
+}
+
+BENCHMARK(BM_ThreadSwitch)->Range(2, 16)->UseRealTime();
+
+void BM_ThreadStart(benchmark::State& state) {
+ const int num_threads = state.range(0);
+
+ for (auto _ : state) {
+ state.PauseTiming();
+
+ auto barrier = new absl::Barrier(num_threads + 1);
+ std::vector<std::unique_ptr<ScopedThread>> threads;
+
+ state.ResumeTiming();
+
+ for (size_t i = 0; i < num_threads; ++i) {
+ threads.emplace_back(std::make_unique<ScopedThread>([barrier] {
+ if (barrier->Block()) {
+ delete barrier;
+ }
+ }));
+ }
+
+ if (barrier->Block()) {
+ delete barrier;
+ }
+
+ state.PauseTiming();
+
+ for (const auto& thread : threads) {
+ thread->Join();
+ }
+
+ state.ResumeTiming();
+ }
+}
+
+BENCHMARK(BM_ThreadStart)->Range(1, 2048)->UseRealTime();
+
+// Benchmark the complete fork + exit + wait.
+void BM_ProcessLifecycle(benchmark::State& state) {
+ const int num_procs = state.range(0);
+
+ std::vector<pid_t> pids(num_procs);
+ for (auto _ : state) {
+ for (size_t i = 0; i < num_procs; ++i) {
+ int pid = fork();
+ if (pid == 0) {
+ _exit(0);
+ }
+ ASSERT_THAT(pid, SyscallSucceeds());
+ pids[i] = pid;
+ }
+
+ for (const int pid : pids) {
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, nullptr, 0),
+ SyscallSucceedsWithValue(pid));
+ }
+ }
+}
+
+BENCHMARK(BM_ProcessLifecycle)->Range(1, 512)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/futex_benchmark.cc b/test/perf/linux/futex_benchmark.cc
new file mode 100644
index 000000000..b349d50bf
--- /dev/null
+++ b/test/perf/linux/futex_benchmark.cc
@@ -0,0 +1,248 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <linux/futex.h>
+
+#include <atomic>
+#include <cerrno>
+#include <cstdint>
+#include <cstdlib>
+#include <ctime>
+
+#include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+inline int FutexWait(std::atomic<int32_t>* v, int32_t val) {
+ return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, nullptr);
+}
+
+inline int FutexWaitRelativeTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* reltime) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_PRIVATE, reltime);
+}
+
+inline int FutexWaitAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
+ const struct timespec* abstime) {
+ return syscall(SYS_futex, v, FUTEX_BITSET_MATCH_ANY, abstime);
+}
+
+inline int FutexWaitBitsetAbsoluteTimeout(std::atomic<int32_t>* v, int32_t val,
+ int32_t bits,
+ const struct timespec* abstime) {
+ return syscall(SYS_futex, v, FUTEX_WAIT_BITSET_PRIVATE | FUTEX_CLOCK_REALTIME,
+ val, abstime, nullptr, bits);
+}
+
+inline int FutexWake(std::atomic<int32_t>* v, int32_t count) {
+ return syscall(SYS_futex, v, FUTEX_WAKE_PRIVATE, count);
+}
+
+// This just uses FUTEX_WAKE on an address with nothing waiting, very simple.
+void BM_FutexWakeNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(0, FutexWake(&v, 1));
+ }
+}
+
+BENCHMARK(BM_FutexWakeNop);
+
+// This just uses FUTEX_WAIT on an address whose value has changed, i.e., the
+// syscall won't wait.
+void BM_FutexWaitNop(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+
+ for (auto _ : state) {
+ EXPECT_EQ(-EAGAIN, FutexWait(&v, 1));
+ }
+}
+
+BENCHMARK(BM_FutexWaitNop);
+
+// This uses FUTEX_WAIT with a timeout on an address whose value never
+// changes, such that it always times out. Timeout overhead can be estimated by
+// timer overruns for short timeouts.
+void BM_FutexWaitTimeout(benchmark::State& state) {
+ const int timeout_ns = state.range(0);
+ std::atomic<int32_t> v(0);
+ auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+
+ for (auto _ : state) {
+ EXPECT_EQ(-ETIMEDOUT, FutexWaitRelativeTimeout(&v, 0, &ts));
+ }
+}
+
+BENCHMARK(BM_FutexWaitTimeout)
+ ->Arg(1)
+ ->Arg(10)
+ ->Arg(100)
+ ->Arg(1000)
+ ->Arg(10000);
+
+// This calls FUTEX_WAIT_BITSET with CLOCK_REALTIME.
+void BM_FutexWaitBitset(benchmark::State& state) {
+ std::atomic<int32_t> v(0);
+ int timeout_ns = state.range(0);
+ auto ts = absl::ToTimespec(absl::Nanoseconds(timeout_ns));
+ for (auto _ : state) {
+ EXPECT_EQ(-ETIMEDOUT, FutexWaitBitsetAbsoluteTimeout(&v, 0, 1, &ts));
+ }
+}
+
+BENCHMARK(BM_FutexWaitBitset)->Range(0, 100000);
+
+int64_t GetCurrentMonotonicTimeNanos() {
+ struct timespec ts;
+ TEST_CHECK(clock_gettime(CLOCK_MONOTONIC, &ts) != -1);
+ return ts.tv_sec * 1000000000ULL + ts.tv_nsec;
+}
+
+void SpinNanos(int64_t delay_ns) {
+ if (delay_ns <= 0) {
+ return;
+ }
+ const int64_t end = GetCurrentMonotonicTimeNanos() + delay_ns;
+ while (GetCurrentMonotonicTimeNanos() < end) {
+ // spin
+ }
+}
+
+// Each iteration of FutexRoundtripDelayed involves a thread sending a futex
+// wakeup to another thread, which spins for delay_us and then sends a futex
+// wakeup back. The time per iteration is 2* (delay_us + kBeforeWakeDelayNs +
+// futex/scheduling overhead).
+void BM_FutexRoundtripDelayed(benchmark::State& state) {
+ const int delay_us = state.range(0);
+
+ const int64_t delay_ns = delay_us * 1000;
+ // Spin for an extra kBeforeWakeDelayNs before invoking FUTEX_WAKE to reduce
+ // the probability that the wakeup comes before the wait, preventing the wait
+ // from ever taking effect and causing the benchmark to underestimate the
+ // actual wakeup time.
+ constexpr int64_t kBeforeWakeDelayNs = 500;
+ std::atomic<int32_t> v(0);
+ ScopedThread t([&] {
+ for (int i = 0; i < state.max_iterations; i++) {
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 0) {
+ FutexWait(&v, 0);
+ }
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(0, std::memory_order_release);
+ FutexWake(&v, 1);
+ }
+ });
+ for (auto _ : state) {
+ SpinNanos(kBeforeWakeDelayNs + delay_ns);
+ v.store(1, std::memory_order_release);
+ FutexWake(&v, 1);
+ SpinNanos(delay_ns);
+ while (v.load(std::memory_order_acquire) == 1) {
+ FutexWait(&v, 1);
+ }
+ }
+}
+
+BENCHMARK(BM_FutexRoundtripDelayed)
+ ->Arg(0)
+ ->Arg(10)
+ ->Arg(20)
+ ->Arg(50)
+ ->Arg(100);
+
+// FutexLock is a simple, dumb futex based lock implementation.
+// It will try to acquire the lock by atomically incrementing the
+// lock word. If it did not increment the lock from 0 to 1, someone
+// else has the lock, so it will FUTEX_WAIT until it is woken in
+// the unlock path.
+class FutexLock {
+ public:
+ FutexLock() : lock_word_(0) {}
+
+ void lock(struct timespec* deadline) {
+ int32_t val;
+ while ((val = lock_word_.fetch_add(1, std::memory_order_acquire) + 1) !=
+ 1) {
+ // If we didn't get the lock by incrementing from 0 to 1,
+ // do a FUTEX_WAIT with the desired current value set to
+ // val. If val is no longer what the atomic increment returned,
+ // someone might have set it to 0 so we can try to acquire
+ // again.
+ int ret = FutexWaitAbsoluteTimeout(&lock_word_, val, deadline);
+ if (ret == 0 || ret == -EWOULDBLOCK || ret == -EINTR) {
+ continue;
+ } else {
+ FAIL() << "unexpected FUTEX_WAIT return: " << ret;
+ }
+ }
+ }
+
+ void unlock() {
+ // Store 0 into the lock word and wake one waiter. We intentionally
+ // ignore the return value of the FUTEX_WAKE here, since there may be
+ // no waiters to wake anyway.
+ lock_word_.store(0, std::memory_order_release);
+ (void)FutexWake(&lock_word_, 1);
+ }
+
+ private:
+ std::atomic<int32_t> lock_word_;
+};
+
+FutexLock* test_lock; // Used below.
+
+void FutexContend(benchmark::State& state, int thread_index,
+ struct timespec* deadline) {
+ int counter = 0;
+ if (thread_index == 0) {
+ test_lock = new FutexLock();
+ }
+ for (auto _ : state) {
+ test_lock->lock(deadline);
+ counter++;
+ test_lock->unlock();
+ }
+ if (thread_index == 0) {
+ delete test_lock;
+ }
+ state.SetItemsProcessed(state.iterations());
+}
+
+void BM_FutexContend(benchmark::State& state) {
+ FutexContend(state, state.thread_index, nullptr);
+}
+
+BENCHMARK(BM_FutexContend)->ThreadRange(1, 1024)->UseRealTime();
+
+void BM_FutexDeadlineContend(benchmark::State& state) {
+ auto deadline = absl::ToTimespec(absl::Now() + absl::Minutes(10));
+ FutexContend(state, state.thread_index, &deadline);
+}
+
+BENCHMARK(BM_FutexDeadlineContend)->ThreadRange(1, 1024)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
new file mode 100644
index 000000000..afc599ad2
--- /dev/null
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+#ifndef SYS_getdents64
+#if defined(__x86_64__)
+#define SYS_getdents64 217
+#elif defined(__aarch64__)
+#define SYS_getdents64 217
+#else
+#error "Unknown architecture"
+#endif
+#endif // SYS_getdents64
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr int kBufferSize = 16384;
+
+PosixErrorOr<TempPath> CreateDirectory(int count,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(TempPath dir, TempPath::CreateDir());
+
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (int i = 0; i < count; i++) {
+ auto file = NewTempRelPath();
+ auto res = MknodAt(dfd, file, S_IFREG | 0644, 0);
+ RETURN_IF_ERRNO(res);
+ files->push_back(file);
+ }
+
+ return std::move(dir);
+}
+
+PosixError CleanupDirectory(const TempPath& dir,
+ std::vector<std::string>* files) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor dfd,
+ Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ for (auto it = files->begin(); it != files->end(); ++it) {
+ auto res = UnlinkAt(dfd, *it, 0);
+ RETURN_IF_ERRNO(res);
+ }
+ return NoError();
+}
+
+// Creates a directory containing `files` files, and reads all the directory
+// entries from the directory using a single FD.
+void BM_GetdentsSameFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ ASSERT_THAT(lseek(fd.get(), 0, SEEK_SET), SyscallSucceeds());
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime();
+
+// Creates a directory containing `files` files, and reads all the directory
+// entries from the directory using a new FD each time.
+void BM_GetdentsNewFD(benchmark::State& state) {
+ // Create directory with given files.
+ const int count = state.range(0);
+
+ // Keep a vector of all of the file TempPaths that is destroyed before dir.
+ //
+ // Normally, we'd simply allow dir to recursively clean up the contained
+ // files, but that recursive cleanup uses getdents, which may be very slow in
+ // extreme benchmarks.
+ TempPath dir;
+ std::vector<std::string> files;
+ dir = ASSERT_NO_ERRNO_AND_VALUE(CreateDirectory(count, &files));
+ char buffer[kBufferSize];
+
+ // We read all directory entries on each iteration, but report this as a
+ // "batch" iteration so that reported times are per file.
+ while (state.KeepRunningBatch(count)) {
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY));
+
+ int ret;
+ do {
+ ASSERT_THAT(ret = syscall(SYS_getdents64, fd.get(), buffer, kBufferSize),
+ SyscallSucceeds());
+ } while (ret > 0);
+ }
+
+ ASSERT_NO_ERRNO(CleanupDirectory(dir, &files));
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc
new file mode 100644
index 000000000..db74cb264
--- /dev/null
+++ b/test/perf/linux/getpid_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Getpid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_getpid);
+ }
+}
+
+BENCHMARK(BM_Getpid);
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/gettid_benchmark.cc b/test/perf/linux/gettid_benchmark.cc
new file mode 100644
index 000000000..8f4961f5e
--- /dev/null
+++ b/test/perf/linux/gettid_benchmark.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Gettid(benchmark::State& state) {
+ for (auto _ : state) {
+ syscall(SYS_gettid);
+ }
+}
+
+BENCHMARK(BM_Gettid)->ThreadRange(1, 4000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/mapping_benchmark.cc b/test/perf/linux/mapping_benchmark.cc
new file mode 100644
index 000000000..39c30fe69
--- /dev/null
+++ b/test/perf/linux/mapping_benchmark.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/memory_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Conservative value for /proc/sys/vm/max_map_count, which limits the number of
+// VMAs, minus a safety margin for VMAs that already exist for the test binary.
+// The default value for max_map_count is
+// include/linux/mm.h:DEFAULT_MAX_MAP_COUNT = 65530.
+constexpr size_t kMaxVMAs = 64001;
+
+// Map then unmap pages without touching them.
+void BM_MapUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map, touch, then unmap pages.
+void BM_MapTouchUnmap(benchmark::State& state) {
+ // Number of pages to map.
+ const int pages = state.range(0);
+
+ while (state.KeepRunning()) {
+ void* addr = mmap(0, pages * kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ char* end = c + pages * kPageSize;
+ while (c < end) {
+ *c = 42;
+ c += kPageSize;
+ }
+
+ int ret = munmap(addr, pages * kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+}
+
+BENCHMARK(BM_MapTouchUnmap)->Range(1, 1 << 17)->UseRealTime();
+
+// Map and touch many pages, unmapping all at once.
+//
+// NOTE(b/111429208): This is a regression test to ensure performant mapping and
+// allocation even with tons of mappings.
+void BM_MapTouchMany(benchmark::State& state) {
+ // Number of pages to map.
+ const int page_count = state.range(0);
+
+ while (state.KeepRunning()) {
+ std::vector<void*> pages;
+
+ for (int i = 0; i < page_count; i++) {
+ void* addr = mmap(nullptr, kPageSize, PROT_READ | PROT_WRITE,
+ MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
+ TEST_CHECK_MSG(addr != MAP_FAILED, "mmap failed");
+
+ char* c = reinterpret_cast<char*>(addr);
+ *c = 42;
+
+ pages.push_back(addr);
+ }
+
+ for (void* addr : pages) {
+ int ret = munmap(addr, kPageSize);
+ TEST_CHECK_MSG(ret == 0, "munmap failed");
+ }
+ }
+
+ state.SetBytesProcessed(kPageSize * page_count * state.iterations());
+}
+
+BENCHMARK(BM_MapTouchMany)->Range(1, 1 << 12)->UseRealTime();
+
+void BM_PageFault(benchmark::State& state) {
+ // Map the region in which we will take page faults. To ensure that each page
+ // fault maps only a single page, each page we touch must correspond to a
+ // distinct VMA. Thus we need a 1-page gap between each 1-page VMA. However,
+ // each gap consists of a PROT_NONE VMA, instead of an unmapped hole, so that
+ // if there are background threads running, they can't inadvertently creating
+ // mappings in our gaps that are unmapped when the test ends.
+ size_t test_pages = kMaxVMAs;
+ // Ensure that test_pages is odd, since we want the test region to both
+ // begin and end with a mapped page.
+ if (test_pages % 2 == 0) {
+ test_pages--;
+ }
+ const size_t test_region_bytes = test_pages * kPageSize;
+ // Use MAP_SHARED here because madvise(MADV_DONTNEED) on private mappings on
+ // gVisor won't force future sentry page faults (by design). Use MAP_POPULATE
+ // so that Linux pre-allocates the shmem file used to back the mapping.
+ Mapping m = ASSERT_NO_ERRNO_AND_VALUE(
+ MmapAnon(test_region_bytes, PROT_READ, MAP_SHARED | MAP_POPULATE));
+ for (size_t i = 0; i < test_pages / 2; i++) {
+ ASSERT_THAT(
+ mprotect(reinterpret_cast<void*>(m.addr() + ((2 * i + 1) * kPageSize)),
+ kPageSize, PROT_NONE),
+ SyscallSucceeds());
+ }
+
+ const size_t mapped_pages = test_pages / 2 + 1;
+ // "Start" at the end of the mapped region to force the mapped region to be
+ // reset, since we mapped it with MAP_POPULATE.
+ size_t cur_page = mapped_pages;
+ for (auto _ : state) {
+ if (cur_page >= mapped_pages) {
+ // We've reached the end of our mapped region and have to reset it to
+ // incur page faults again.
+ state.PauseTiming();
+ ASSERT_THAT(madvise(m.ptr(), test_region_bytes, MADV_DONTNEED),
+ SyscallSucceeds());
+ cur_page = 0;
+ state.ResumeTiming();
+ }
+ const uintptr_t addr = m.addr() + (2 * cur_page * kPageSize);
+ const char c = *reinterpret_cast<volatile char*>(addr);
+ benchmark::DoNotOptimize(c);
+ cur_page++;
+ }
+}
+
+BENCHMARK(BM_PageFault)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/open_benchmark.cc b/test/perf/linux/open_benchmark.cc
new file mode 100644
index 000000000..68008f6d5
--- /dev/null
+++ b/test/perf/linux/open_benchmark.cc
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Open(benchmark::State& state) {
+ const int size = state.range(0);
+ std::vector<TempPath> cache;
+ for (int i = 0; i < size; i++) {
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ cache.emplace_back(std::move(path));
+ }
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ const int chosen = rand_r(&seed) % size;
+ int fd = open(cache[chosen].path().c_str(), O_RDONLY);
+ TEST_CHECK(fd != -1);
+ close(fd);
+ }
+}
+
+BENCHMARK(BM_Open)->Range(1, 128)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/pipe_benchmark.cc b/test/perf/linux/pipe_benchmark.cc
new file mode 100644
index 000000000..8f5f6a2a3
--- /dev/null
+++ b/test/perf/linux/pipe_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include <cerrno>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Pipe(benchmark::State& state) {
+ int fds[2];
+ TEST_CHECK(pipe(fds) == 0);
+
+ const int size = state.range(0);
+ std::vector<char> wbuf(size);
+ std::vector<char> rbuf(size);
+ RandomizeBuffer(wbuf.data(), size);
+
+ ScopedThread t([&] {
+ auto const fd = fds[1];
+ for (int i = 0; i < state.max_iterations; i++) {
+ TEST_CHECK(WriteFd(fd, wbuf.data(), wbuf.size()) == size);
+ }
+ });
+
+ for (auto _ : state) {
+ TEST_CHECK(ReadFd(fds[0], rbuf.data(), rbuf.size()) == size);
+ }
+
+ t.Join();
+
+ close(fds[0]);
+ close(fds[1]);
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Pipe)->Range(1, 1 << 20)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/randread_benchmark.cc b/test/perf/linux/randread_benchmark.cc
new file mode 100644
index 000000000..b0eb8c24e
--- /dev/null
+++ b/test/perf/linux/randread_benchmark.cc
@@ -0,0 +1,100 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Create a 1GB file that will be read from at random positions. This should
+// invalid any performance gains from caching.
+const uint64_t kFileSize = 1ULL << 30;
+
+// How many bytes to write at once to initialize the file used to read from.
+const uint32_t kWriteSize = 65536;
+
+// Largest benchmarked read unit.
+const uint32_t kMaxRead = 1UL << 26;
+
+TempPath CreateFile(uint64_t file_size) {
+ auto path = TempPath::CreateFile().ValueOrDie();
+ FileDescriptor fd = Open(path.path(), O_WRONLY).ValueOrDie();
+
+ // Try to minimize syscalls by using maximum size writev() requests.
+ std::vector<char> buffer(kWriteSize);
+ RandomizeBuffer(buffer.data(), buffer.size());
+ const std::vector<std::vector<struct iovec>> iovecs_list =
+ GenerateIovecs(file_size, buffer.data(), buffer.size());
+ for (const auto& iovecs : iovecs_list) {
+ TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0);
+ }
+
+ return path;
+}
+
+// Global test state, initialized once per process lifetime.
+struct GlobalState {
+ const TempPath tmpfile;
+ explicit GlobalState(TempPath tfile) : tmpfile(std::move(tfile)) {}
+};
+
+GlobalState& GetGlobalState() {
+ // This gets created only once throughout the lifetime of the process.
+ // Use a dynamically allocated object (that is never deleted) to avoid order
+ // of destruction of static storage variables issues.
+ static GlobalState* const state =
+ // The actual file size is the maximum random seek range (kFileSize) + the
+ // maximum read size so we can read that number of bytes at the end of the
+ // file.
+ new GlobalState(CreateFile(kFileSize + kMaxRead));
+ return *state;
+}
+
+void BM_RandRead(benchmark::State& state) {
+ const int size = state.range(0);
+
+ GlobalState& global_state = GetGlobalState();
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(global_state.tmpfile.path(), O_RDONLY));
+ std::vector<char> buf(size);
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(),
+ rand_r(&seed) % kFileSize) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_RandRead)->Range(1, kMaxRead)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/read_benchmark.cc b/test/perf/linux/read_benchmark.cc
new file mode 100644
index 000000000..62445867d
--- /dev/null
+++ b/test/perf/linux/read_benchmark.cc
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Read(benchmark::State& state) {
+ const int size = state.range(0);
+ const std::string contents(size, 0);
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), contents, TempPath::kDefaultFileMode));
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(path.path(), O_RDONLY));
+
+ std::vector<char> buf(size);
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Read)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sched_yield_benchmark.cc b/test/perf/linux/sched_yield_benchmark.cc
new file mode 100644
index 000000000..6756b5575
--- /dev/null
+++ b/test/perf/linux/sched_yield_benchmark.cc
@@ -0,0 +1,37 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sched.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Sched_yield(benchmark::State& state) {
+ for (auto ignored : state) {
+ TEST_CHECK(sched_yield() == 0);
+ }
+}
+
+BENCHMARK(BM_Sched_yield)->ThreadRange(1, 2000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/send_recv_benchmark.cc b/test/perf/linux/send_recv_benchmark.cc
new file mode 100644
index 000000000..d73e49523
--- /dev/null
+++ b/test/perf/linux/send_recv_benchmark.cc
@@ -0,0 +1,372 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <poll.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+
+#include <cstring>
+
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "benchmark/benchmark.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/logging.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+constexpr ssize_t kMessageSize = 1024;
+
+class Message {
+ public:
+ explicit Message(int byte = 0) : Message(byte, kMessageSize, 0) {}
+
+ explicit Message(int byte, int sz) : Message(byte, sz, 0) {}
+
+ explicit Message(int byte, int sz, int cmsg_sz)
+ : buffer_(sz, byte), cmsg_buffer_(cmsg_sz, 0) {
+ iov_.iov_base = buffer_.data();
+ iov_.iov_len = sz;
+ hdr_.msg_iov = &iov_;
+ hdr_.msg_iovlen = 1;
+ hdr_.msg_control = cmsg_buffer_.data();
+ hdr_.msg_controllen = cmsg_sz;
+ }
+
+ struct msghdr* header() {
+ return &hdr_;
+ }
+
+ private:
+ std::vector<char> buffer_;
+ std::vector<char> cmsg_buffer_;
+ struct iovec iov_ = {};
+ struct msghdr hdr_ = {};
+};
+
+void BM_Recvmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvmsg)->UseRealTime();
+
+void BM_Sendmsg(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ Message send_msg('a'), recv_msg;
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendmsg(send_socket.get(), send_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendmsg)->UseRealTime();
+
+void BM_Recvfrom(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&send_socket, &send_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ }
+ });
+
+ int bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_Recvfrom)->UseRealTime();
+
+void BM_Sendto(benchmark::State& state) {
+ int sockets[2];
+ TEST_CHECK(socketpair(AF_UNIX, SOCK_STREAM, 0, sockets) == 0);
+ FileDescriptor send_socket(sockets[0]), recv_socket(sockets[1]);
+ absl::Notification notification;
+ char send_buffer[kMessageSize], recv_buffer[kMessageSize];
+
+ ScopedThread t([&recv_socket, &recv_buffer, &notification] {
+ while (!notification.HasBeenNotified()) {
+ recvfrom(recv_socket.get(), recv_buffer, kMessageSize, 0, nullptr,
+ nullptr);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ for (auto ignored : state) {
+ int n = sendto(send_socket.get(), send_buffer, kMessageSize, 0, nullptr, 0);
+ TEST_CHECK(n > 0);
+ bytes_sent += n;
+ }
+
+ notification.Notify();
+ send_socket.reset();
+
+ state.SetBytesProcessed(bytes_sent);
+}
+
+BENCHMARK(BM_Sendto)->UseRealTime();
+
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+ struct sockaddr_storage addr;
+ memset(&addr, 0, sizeof(addr));
+ addr.ss_family = family;
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+ htonl(INADDR_LOOPBACK);
+ break;
+ case AF_INET6:
+ reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+ in6addr_loopback;
+ break;
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+ return addr;
+}
+
+// BM_RecvmsgWithControlBuf measures the performance of recvmsg when we allocate
+// space for control messages. Note that we do not expect to receive any.
+void BM_RecvmsgWithControlBuf(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET6));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ absl::Notification notification;
+ Message send_msg('a');
+ // Create a msghdr with a buffer allocated for control messages.
+ Message recv_msg(0, kMessageSize, /*cmsg_sz=*/24);
+
+ ScopedThread t([&send_msg, &send_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ sendmsg(send_socket.get(), send_msg.header(), 0);
+ }
+ });
+
+ int64_t bytes_received = 0;
+ for (auto ignored : state) {
+ int n = recvmsg(recv_socket.get(), recv_msg.header(), 0);
+ TEST_CHECK(n > 0);
+ bytes_received += n;
+ }
+
+ notification.Notify();
+ recv_socket.reset();
+
+ state.SetBytesProcessed(bytes_received);
+}
+
+BENCHMARK(BM_RecvmsgWithControlBuf)->UseRealTime();
+
+// BM_SendmsgTCP measures the sendmsg throughput with varying payload sizes.
+//
+// state.Args[0] indicates whether the underlying socket should be blocking or
+// non-blocking w/ 0 indicating non-blocking and 1 to indicate blocking.
+// state.Args[1] is the size of the payload to be used per sendmsg call.
+void BM_SendmsgTCP(benchmark::State& state) {
+ auto listen_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ // Initialize address to the loopback one.
+ sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
+ socklen_t addrlen = sizeof(addr);
+
+ // Bind to some port then start listening.
+ ASSERT_THAT(bind(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+ // Get the address we're listening on, then connect to it. We need to do this
+ // because we're allowing the stack to pick a port for us.
+ ASSERT_THAT(getsockname(listen_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+ SyscallSucceeds());
+
+ auto send_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+ ASSERT_THAT(
+ RetryEINTR(connect)(send_socket.get(),
+ reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+ SyscallSucceeds());
+
+ // Accept the connection.
+ auto recv_socket =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+ // Check if we want to run the test w/ a blocking send socket
+ // or non-blocking.
+ const int blocking = state.range(0);
+ if (!blocking) {
+ // Set the send FD to O_NONBLOCK.
+ int opts;
+ ASSERT_THAT(opts = fcntl(send_socket.get(), F_GETFL), SyscallSucceeds());
+ opts |= O_NONBLOCK;
+ ASSERT_THAT(fcntl(send_socket.get(), F_SETFL, opts), SyscallSucceeds());
+ }
+
+ absl::Notification notification;
+
+ // Get the buffer size we should use for this iteration of the test.
+ const int buf_size = state.range(1);
+ Message send_msg('a', buf_size), recv_msg(0, buf_size);
+
+ ScopedThread t([&recv_msg, &recv_socket, &notification] {
+ while (!notification.HasBeenNotified()) {
+ TEST_CHECK(recvmsg(recv_socket.get(), recv_msg.header(), 0) >= 0);
+ }
+ });
+
+ int64_t bytes_sent = 0;
+ int ncalls = 0;
+ for (auto ignored : state) {
+ int sent = 0;
+ while (true) {
+ struct msghdr hdr = {};
+ struct iovec iov = {};
+ struct msghdr* snd_header = send_msg.header();
+ iov.iov_base = static_cast<char*>(snd_header->msg_iov->iov_base) + sent;
+ iov.iov_len = snd_header->msg_iov->iov_len - sent;
+ hdr.msg_iov = &iov;
+ hdr.msg_iovlen = 1;
+ int n = RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0);
+ ncalls++;
+ if (n > 0) {
+ sent += n;
+ if (sent == buf_size) {
+ break;
+ }
+ // n can be > 0 but less than requested size. In which case we don't
+ // poll.
+ continue;
+ }
+ // Poll the fd for it to become writable.
+ struct pollfd poll_fd = {send_socket.get(), POLL_OUT, 0};
+ EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10),
+ SyscallSucceedsWithValue(0));
+ }
+ bytes_sent += static_cast<int64_t>(sent);
+ }
+
+ notification.Notify();
+ send_socket.reset();
+ state.SetBytesProcessed(bytes_sent);
+}
+
+void Args(benchmark::internal::Benchmark* benchmark) {
+ for (int blocking = 0; blocking < 2; blocking++) {
+ for (int buf_size = 1024; buf_size <= 256 << 20; buf_size *= 2) {
+ benchmark->Args({blocking, buf_size});
+ }
+ }
+}
+
+BENCHMARK(BM_SendmsgTCP)->Apply(&Args)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/seqwrite_benchmark.cc b/test/perf/linux/seqwrite_benchmark.cc
new file mode 100644
index 000000000..af49e4477
--- /dev/null
+++ b/test/perf/linux/seqwrite_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// The maximum file size of the test file, when writes get beyond this point
+// they wrap around. This should be large enough to blow away caches.
+const uint64_t kMaxFile = 1 << 30;
+
+// Perform writes of various sizes sequentially to one file. Wraps around if it
+// goes above a certain maximum file size.
+void BM_SeqWrite(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), buf.size());
+
+ // Start writes at offset 0.
+ uint64_t offset = 0;
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), buf.size(), offset) ==
+ buf.size());
+ offset += buf.size();
+ // Wrap around if going above the maximum file size.
+ if (offset >= kMaxFile) {
+ offset = 0;
+ }
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_SeqWrite)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/signal_benchmark.cc b/test/perf/linux/signal_benchmark.cc
new file mode 100644
index 000000000..a6928df58
--- /dev/null
+++ b/test/perf/linux/signal_benchmark.cc
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <signal.h>
+#include <string.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void FixupHandler(int sig, siginfo_t* si, void* void_ctx) {
+ static unsigned int dataval = 0;
+
+ // Skip the offending instruction.
+ ucontext_t* ctx = reinterpret_cast<ucontext_t*>(void_ctx);
+ ctx->uc_mcontext.gregs[REG_RAX] = reinterpret_cast<greg_t>(&dataval);
+}
+
+void BM_FaultSignalFixup(benchmark::State& state) {
+ // Set up the signal handler.
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_sigaction = FixupHandler;
+ sa.sa_flags = SA_SIGINFO;
+ TEST_CHECK(sigaction(SIGSEGV, &sa, nullptr) == 0);
+
+ // Fault, fault, fault.
+ for (auto _ : state) {
+ register volatile unsigned int* ptr asm("rax");
+
+ // Trigger the segfault.
+ ptr = nullptr;
+ *ptr = 0;
+ }
+}
+
+BENCHMARK(BM_FaultSignalFixup)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/sleep_benchmark.cc b/test/perf/linux/sleep_benchmark.cc
new file mode 100644
index 000000000..99ef05117
--- /dev/null
+++ b/test/perf/linux/sleep_benchmark.cc
@@ -0,0 +1,60 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <sys/syscall.h>
+#include <time.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Sleep for 'param' nanoseconds.
+void BM_Sleep(benchmark::State& state) {
+ const int nanoseconds = state.range(0);
+
+ for (auto _ : state) {
+ struct timespec ts;
+ ts.tv_sec = 0;
+ ts.tv_nsec = nanoseconds;
+
+ int ret;
+ do {
+ ret = syscall(SYS_nanosleep, &ts, &ts);
+ if (ret < 0) {
+ TEST_CHECK(errno == EINTR);
+ }
+ } while (ret < 0);
+ }
+}
+
+BENCHMARK(BM_Sleep)
+ ->Arg(0)
+ ->Arg(1)
+ ->Arg(1000) // 1us
+ ->Arg(1000 * 1000) // 1ms
+ ->Arg(10 * 1000 * 1000) // 10ms
+ ->Arg(50 * 1000 * 1000) // 50ms
+ ->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/stat_benchmark.cc b/test/perf/linux/stat_benchmark.cc
new file mode 100644
index 000000000..f15424482
--- /dev/null
+++ b/test/perf/linux/stat_benchmark.cc
@@ -0,0 +1,62 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/strings/str_cat.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a file in a nested directory hierarchy at least `depth` directories
+// deep, and stats that file multiple times.
+void BM_Stat(benchmark::State& state) {
+ // Create nested directories with given depth.
+ int depth = state.range(0);
+ const TempPath top_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string dir_path = top_dir.path();
+
+ while (depth-- > 0) {
+ // Don't use TempPath because it will make paths too long to use.
+ //
+ // The top_dir destructor will clean up this whole tree.
+ dir_path = JoinPath(dir_path, absl::StrCat(depth));
+ ASSERT_NO_ERRNO(Mkdir(dir_path, 0755));
+ }
+
+ // Create the file that will be stat'd.
+ const TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir_path));
+
+ struct stat st;
+ for (auto _ : state) {
+ ASSERT_THAT(stat(file.path().c_str(), &st), SyscallSucceeds());
+ }
+}
+
+BENCHMARK(BM_Stat)->Range(1, 100)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/unlink_benchmark.cc b/test/perf/linux/unlink_benchmark.cc
new file mode 100644
index 000000000..92243a042
--- /dev/null
+++ b/test/perf/linux/unlink_benchmark.cc
@@ -0,0 +1,66 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Creates a directory containing `files` files, and unlinks all the files.
+void BM_Unlink(benchmark::State& state) {
+ // Create directory with given files.
+ const int file_count = state.range(0);
+
+ // We unlink all files on each iteration, but report this as a "batch"
+ // iteration so that reported times are per file.
+ TempPath dir;
+ while (state.KeepRunningBatch(file_count)) {
+ state.PauseTiming();
+ // N.B. dir is declared outside the loop so that destruction of the previous
+ // iteration's directory occurs here, inside of PauseTiming.
+ dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+
+ std::vector<TempPath> files;
+ for (int i = 0; i < file_count; i++) {
+ TempPath file =
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path()));
+ files.push_back(std::move(file));
+ }
+ state.ResumeTiming();
+
+ while (!files.empty()) {
+ // Destructor unlinks.
+ files.pop_back();
+ }
+ }
+
+ state.SetItemsProcessed(state.iterations());
+}
+
+BENCHMARK(BM_Unlink)->Range(1, 100 * 1000)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/write_benchmark.cc b/test/perf/linux/write_benchmark.cc
new file mode 100644
index 000000000..7b060c70e
--- /dev/null
+++ b/test/perf/linux/write_benchmark.cc
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_Write(benchmark::State& state) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_WRONLY));
+
+ const int size = state.range(0);
+ std::vector<char> buf(size);
+ RandomizeBuffer(buf.data(), size);
+
+ for (auto _ : state) {
+ TEST_CHECK(PwriteFd(fd.get(), buf.data(), size, 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_Write)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/runner/BUILD b/test/runner/BUILD
new file mode 100644
index 000000000..9959ef9b0
--- /dev/null
+++ b/test/runner/BUILD
@@ -0,0 +1,22 @@
+load("//tools:defs.bzl", "go_binary")
+
+package(licenses = ["notice"])
+
+go_binary(
+ name = "runner",
+ testonly = 1,
+ srcs = ["runner.go"],
+ data = [
+ "//runsc",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//runsc/specutils",
+ "//runsc/testutil",
+ "//test/runner/gtest",
+ "//test/uds",
+ "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/test/syscalls/build_defs.bzl b/test/runner/defs.bzl
index cbab85ef7..56743a526 100644
--- a/test/syscalls/build_defs.bzl
+++ b/test/runner/defs.bzl
@@ -1,6 +1,119 @@
"""Defines a rule for syscall test targets."""
-load("//tools:defs.bzl", "loopback")
+load("//tools:defs.bzl", "default_platform", "loopback", "platforms")
+
+def _runner_test_impl(ctx):
+ # Generate a runner binary.
+ runner = ctx.actions.declare_file("%s-runner" % ctx.label.name)
+ runner_content = "\n".join([
+ "#!/bin/bash",
+ "set -euf -x -o pipefail",
+ "if [[ -n \"${TEST_UNDECLARED_OUTPUTS_DIR}\" ]]; then",
+ " mkdir -p \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ " chmod a+rwx \"${TEST_UNDECLARED_OUTPUTS_DIR}\"",
+ "fi",
+ "exec %s %s %s\n" % (
+ ctx.files.runner[0].short_path,
+ " ".join(ctx.attr.runner_args),
+ ctx.files.test[0].short_path,
+ ),
+ ])
+ ctx.actions.write(runner, runner_content, is_executable = True)
+
+ # Return with all transitive files.
+ runfiles = ctx.runfiles(
+ transitive_files = depset(transitive = [
+ depset(target.data_runfiles.files)
+ for target in (ctx.attr.runner, ctx.attr.test)
+ if hasattr(target, "data_runfiles")
+ ]),
+ files = ctx.files.runner + ctx.files.test,
+ collect_default = True,
+ collect_data = True,
+ )
+ return [DefaultInfo(executable = runner, runfiles = runfiles)]
+
+_runner_test = rule(
+ attrs = {
+ "runner": attr.label(
+ default = "//test/runner:runner",
+ ),
+ "test": attr.label(
+ mandatory = True,
+ ),
+ "runner_args": attr.string_list(),
+ "data": attr.label_list(
+ allow_files = True,
+ ),
+ },
+ test = True,
+ implementation = _runner_test_impl,
+)
+
+def _syscall_test(
+ test,
+ shard_count,
+ size,
+ platform,
+ use_tmpfs,
+ tags,
+ network = "none",
+ file_access = "exclusive",
+ overlay = False,
+ add_uds_tree = False):
+ # Prepend "runsc" to non-native platform names.
+ full_platform = platform if platform == "native" else "runsc_" + platform
+
+ # Name the test appropriately.
+ name = test.split(":")[1] + "_" + full_platform
+ if file_access == "shared":
+ name += "_shared"
+ if overlay:
+ name += "_overlay"
+ if network != "none":
+ name += "_" + network + "net"
+
+ # Apply all tags.
+ if tags == None:
+ tags = []
+
+ # Add the full_platform and file access in a tag to make it easier to run
+ # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
+ tags += [full_platform, "file_" + file_access]
+
+ # Hash this target into one of 15 buckets. This can be used to
+ # randomly split targets between different workflows.
+ hash15 = hash(native.package_name() + name) % 15
+ tags.append("hash15:" + str(hash15))
+
+ # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until
+ # we figure out how to request ipv4 sockets on Guitar machines.
+ if network == "host":
+ tags.append("noguitar")
+
+ # Disable off-host networking.
+ tags.append("requires-net:loopback")
+
+ runner_args = [
+ # Arguments are passed directly to runner binary.
+ "--platform=" + platform,
+ "--network=" + network,
+ "--use-tmpfs=" + str(use_tmpfs),
+ "--file-access=" + file_access,
+ "--overlay=" + str(overlay),
+ "--add-uds-tree=" + str(add_uds_tree),
+ ]
+
+ # Call the rule above.
+ _runner_test(
+ name = name,
+ test = test,
+ runner_args = runner_args,
+ data = [loopback],
+ size = size,
+ tags = tags,
+ shard_count = shard_count,
+ )
def syscall_test(
test,
@@ -23,6 +136,8 @@ def syscall_test(
add_hostinet: add a hostinet test.
tags: starting test tags.
"""
+ if not tags:
+ tags = []
_syscall_test(
test = test,
@@ -34,35 +149,26 @@ def syscall_test(
tags = tags,
)
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "kvm",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
+ for (platform, platform_tags) in platforms.items():
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platform_tags + tags,
+ )
if add_overlay:
_syscall_test(
test = test,
shard_count = shard_count,
size = size,
- platform = "ptrace",
+ platform = default_platform,
use_tmpfs = False, # overlay is adding a writable tmpfs on top of root.
add_uds_tree = add_uds_tree,
- tags = tags,
+ tags = platforms[default_platform] + tags,
overlay = True,
)
@@ -72,10 +178,10 @@ def syscall_test(
test = test,
shard_count = shard_count,
size = size,
- platform = "ptrace",
+ platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
- tags = tags,
+ tags = platforms[default_platform] + tags,
file_access = "shared",
)
@@ -84,97 +190,9 @@ def syscall_test(
test = test,
shard_count = shard_count,
size = size,
- platform = "ptrace",
+ platform = default_platform,
use_tmpfs = use_tmpfs,
network = "host",
add_uds_tree = add_uds_tree,
- tags = tags,
+ tags = platforms[default_platform] + tags,
)
-
-def _syscall_test(
- test,
- shard_count,
- size,
- platform,
- use_tmpfs,
- tags,
- network = "none",
- file_access = "exclusive",
- overlay = False,
- add_uds_tree = False):
- test_name = test.split(":")[1]
-
- # Prepend "runsc" to non-native platform names.
- full_platform = platform if platform == "native" else "runsc_" + platform
-
- name = test_name + "_" + full_platform
- if file_access == "shared":
- name += "_shared"
- if overlay:
- name += "_overlay"
- if network != "none":
- name += "_" + network + "net"
-
- if tags == None:
- tags = []
-
- # Add the full_platform and file access in a tag to make it easier to run
- # all the tests on a specific flavor. Use --test_tag_filters=ptrace,file_shared.
- tags += [full_platform, "file_" + file_access]
-
- # Hash this target into one of 15 buckets. This can be used to
- # randomly split targets between different workflows.
- hash15 = hash(native.package_name() + name) % 15
- tags.append("hash15:" + str(hash15))
-
- # TODO(b/139838000): Tests using hostinet must be disabled on Guitar until
- # we figure out how to request ipv4 sockets on Guitar machines.
- if network == "host":
- tags.append("noguitar")
-
- # Disable off-host networking.
- tags.append("requires-net:loopback")
-
- # Add tag to prevent the tests from running in a Bazel sandbox.
- # TODO(b/120560048): Make the tests run without this tag.
- tags.append("no-sandbox")
-
- # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is
- # more stable.
- if platform == "kvm":
- tags.append("manual")
- tags.append("requires-kvm")
-
- # TODO(b/112165693): Remove when tests pass reliably.
- tags.append("notap")
-
- args = [
- # Arguments are passed directly to syscall_test_runner binary.
- "--test-name=" + test_name,
- "--platform=" + platform,
- "--network=" + network,
- "--use-tmpfs=" + str(use_tmpfs),
- "--file-access=" + file_access,
- "--overlay=" + str(overlay),
- "--add-uds-tree=" + str(add_uds_tree),
- ]
-
- sh_test(
- srcs = ["syscall_test_runner.sh"],
- name = name,
- data = [
- ":syscall_test_runner",
- loopback,
- test,
- ],
- args = args,
- size = size,
- tags = tags,
- shard_count = shard_count,
- )
-
-def sh_test(**kwargs):
- """Wraps the standard sh_test."""
- native.sh_test(
- **kwargs
- )
diff --git a/test/syscalls/gtest/BUILD b/test/runner/gtest/BUILD
index de4b2727c..de4b2727c 100644
--- a/test/syscalls/gtest/BUILD
+++ b/test/runner/gtest/BUILD
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
new file mode 100644
index 000000000..f96e2415e
--- /dev/null
+++ b/test/runner/gtest/gtest.go
@@ -0,0 +1,167 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gtest contains helpers for running google-test tests from Go.
+package gtest
+
+import (
+ "fmt"
+ "os/exec"
+ "strings"
+)
+
+var (
+ // listTestFlag is the flag that will list tests in gtest binaries.
+ listTestFlag = "--gtest_list_tests"
+
+ // filterTestFlag is the flag that will filter tests in gtest binaries.
+ filterTestFlag = "--gtest_filter"
+
+ // listBechmarkFlag is the flag that will list benchmarks in gtest binaries.
+ listBenchmarkFlag = "--benchmark_list_tests"
+
+ // filterBenchmarkFlag is the flag that will run specified benchmarks.
+ filterBenchmarkFlag = "--benchmark_filter"
+)
+
+// TestCase is a single gtest test case.
+type TestCase struct {
+ // Suite is the suite for this test.
+ Suite string
+
+ // Name is the name of this individual test.
+ Name string
+
+ // all indicates that this will run without flags. This takes
+ // precendence over benchmark below.
+ all bool
+
+ // benchmark indicates that this is a benchmark. In this case, the
+ // suite will be empty, and we will use the appropriate test and
+ // benchmark flags.
+ benchmark bool
+}
+
+// FullName returns the name of the test including the suite. It is suitable to
+// pass to "-gtest_filter".
+func (tc TestCase) FullName() string {
+ return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
+}
+
+// Args returns arguments to be passed when invoking the test.
+func (tc TestCase) Args() []string {
+ if tc.all {
+ return []string{} // No arguments.
+ }
+ if tc.benchmark {
+ return []string{
+ fmt.Sprintf("%s=^$", filterTestFlag),
+ fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name),
+ }
+ }
+ return []string{
+ fmt.Sprintf("%s=^%s$", filterTestFlag, tc.FullName()),
+ fmt.Sprintf("%s=^$", filterBenchmarkFlag),
+ }
+}
+
+// ParseTestCases calls a gtest test binary to list its test and returns a
+// slice with the name and suite of each test.
+//
+// If benchmarks is true, then benchmarks will be included in the list of test
+// cases provided. Note that this requires the binary to support the
+// benchmarks_list_tests flag.
+func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]TestCase, error) {
+ // Run to extract test cases.
+ args := append([]string{listTestFlag}, extraArgs...)
+ cmd := exec.Command(testBin, args...)
+ out, err := cmd.Output()
+ if err != nil {
+ // We failed to list tests with the given flags. Just
+ // return something that will run the binary with no
+ // flags, which should execute all tests.
+ return []TestCase{
+ TestCase{
+ Suite: "Default",
+ Name: "All",
+ all: true,
+ },
+ }, nil
+ }
+
+ // Parse test output.
+ var t []TestCase
+ var suite string
+ for _, line := range strings.Split(string(out), "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // New suite?
+ if !strings.HasPrefix(line, " ") {
+ suite = strings.TrimSuffix(strings.TrimSpace(line), ".")
+ continue
+ }
+
+ // Individual test.
+ name := strings.TrimSpace(line)
+
+ // Do we have a suite yet?
+ if suite == "" {
+ return nil, fmt.Errorf("test without a suite: %v", name)
+ }
+
+ // Add this individual test.
+ t = append(t, TestCase{
+ Suite: suite,
+ Name: name,
+ })
+ }
+
+ // Finished?
+ if !benchmarks {
+ return t, nil
+ }
+
+ // Run again to extract benchmarks.
+ args = append([]string{listBenchmarkFlag}, extraArgs...)
+ cmd = exec.Command(testBin, args...)
+ out, err = cmd.Output()
+ if err != nil {
+ // We were able to enumerate tests above, but not benchmarks?
+ // We requested them, so we return an error in this case.
+ exitErr, ok := err.(*exec.ExitError)
+ if !ok {
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err)
+ }
+ return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr)
+ }
+
+ // Parse benchmark output.
+ for _, line := range strings.Split(string(out), "\n") {
+ // Strip comments.
+ line = strings.Split(line, "#")[0]
+
+ // Single benchmark.
+ name := strings.TrimSpace(line)
+
+ // Add the single benchmark.
+ t = append(t, TestCase{
+ Suite: "Benchmarks",
+ Name: name,
+ benchmark: true,
+ })
+ }
+
+ return t, nil
+}
diff --git a/test/syscalls/syscall_test_runner.go b/test/runner/runner.go
index ae342b68c..a78ef38e0 100644
--- a/test/syscalls/syscall_test_runner.go
+++ b/test/runner/runner.go
@@ -34,15 +34,11 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/specutils"
"gvisor.dev/gvisor/runsc/testutil"
- "gvisor.dev/gvisor/test/syscalls/gtest"
+ "gvisor.dev/gvisor/test/runner/gtest"
"gvisor.dev/gvisor/test/uds"
)
-// Location of syscall tests, relative to the repo root.
-const testDir = "test/syscalls/linux"
-
var (
- testName = flag.String("test-name", "", "name of test binary to run")
debug = flag.Bool("debug", false, "enable debug logs")
strace = flag.Bool("strace", false, "enable strace logs")
platform = flag.String("platform", "ptrace", "platform to run on")
@@ -103,7 +99,7 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir)
}
- cmd := exec.Command(testBin, gtest.FilterTestFlag+"="+tc.FullName())
+ cmd := exec.Command(testBin, tc.Args()...)
cmd.Env = env
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
@@ -296,7 +292,7 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) {
func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Run a new container with the test executable and filter for the
// given test suite and name.
- spec := testutil.NewSpecWithArgs(testBin, gtest.FilterTestFlag+"="+tc.FullName())
+ spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...)
// Mark the root as writeable, as some tests attempt to
// write to the rootfs, and expect EACCES, not EROFS.
@@ -404,9 +400,10 @@ func matchString(a, b string) (bool, error) {
func main() {
flag.Parse()
- if *testName == "" {
- fatalf("test-name flag must be provided")
+ if flag.NArg() != 1 {
+ fatalf("test must be provided")
}
+ testBin := flag.Args()[0] // Only argument.
log.SetLevel(log.Info)
if *debug {
@@ -436,15 +433,8 @@ func main() {
}
}
- // Get path to test binary.
- fullTestName := filepath.Join(testDir, *testName)
- testBin, err := testutil.FindFile(fullTestName)
- if err != nil {
- fatalf("FindFile(%q) failed: %v", fullTestName, err)
- }
-
// Get all test cases in each binary.
- testCases, err := gtest.ParseTestCases(testBin)
+ testCases, err := gtest.ParseTestCases(testBin, true)
if err != nil {
fatalf("ParseTestCases(%q) failed: %v", testBin, err)
}
@@ -455,14 +445,19 @@ func main() {
fatalf("TestsForShard() failed: %v", err)
}
+ // Resolve the absolute path for the binary.
+ testBin, err = filepath.Abs(testBin)
+ if err != nil {
+ fatalf("Abs() failed: %v", err)
+ }
+
// Run the tests.
var tests []testing.InternalTest
for _, tci := range indices {
// Capture tc.
tc := testCases[tci]
- testName := fmt.Sprintf("%s_%s", tc.Suite, tc.Name)
tests = append(tests, testing.InternalTest{
- Name: testName,
+ Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name),
F: func(t *testing.T) {
if *parallel {
t.Parallel()
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 31d239c0e..3518e862d 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -1,5 +1,4 @@
-load("//tools:defs.bzl", "go_binary")
-load("//test/syscalls:build_defs.bzl", "syscall_test")
+load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
@@ -259,6 +258,8 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:munmap_test")
+syscall_test(test = "//test/syscalls/linux:network_namespace_test")
+
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:open_create_test",
@@ -677,6 +678,8 @@ syscall_test(
test = "//test/syscalls/linux:truncate_test",
)
+syscall_test(test = "//test/syscalls/linux:tuntap_test")
+
syscall_test(test = "//test/syscalls/linux:udp_bind_test")
syscall_test(
@@ -726,21 +729,3 @@ syscall_test(test = "//test/syscalls/linux:proc_net_unix_test")
syscall_test(test = "//test/syscalls/linux:proc_net_tcp_test")
syscall_test(test = "//test/syscalls/linux:proc_net_udp_test")
-
-go_binary(
- name = "syscall_test_runner",
- testonly = 1,
- srcs = ["syscall_test_runner.go"],
- data = [
- "//runsc",
- ],
- deps = [
- "//pkg/log",
- "//runsc/specutils",
- "//runsc/testutil",
- "//test/syscalls/gtest",
- "//test/uds",
- "@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
- "@org_golang_x_sys//unix:go_default_library",
- ],
-)
diff --git a/test/syscalls/gtest/gtest.go b/test/syscalls/gtest/gtest.go
deleted file mode 100644
index bdec8eb07..000000000
--- a/test/syscalls/gtest/gtest.go
+++ /dev/null
@@ -1,93 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package gtest contains helpers for running google-test tests from Go.
-package gtest
-
-import (
- "fmt"
- "os/exec"
- "strings"
-)
-
-var (
- // ListTestFlag is the flag that will list tests in gtest binaries.
- ListTestFlag = "--gtest_list_tests"
-
- // FilterTestFlag is the flag that will filter tests in gtest binaries.
- FilterTestFlag = "--gtest_filter"
-)
-
-// TestCase is a single gtest test case.
-type TestCase struct {
- // Suite is the suite for this test.
- Suite string
-
- // Name is the name of this individual test.
- Name string
-}
-
-// FullName returns the name of the test including the suite. It is suitable to
-// pass to "-gtest_filter".
-func (tc TestCase) FullName() string {
- return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
-}
-
-// ParseTestCases calls a gtest test binary to list its test and returns a
-// slice with the name and suite of each test.
-func ParseTestCases(testBin string, extraArgs ...string) ([]TestCase, error) {
- args := append([]string{ListTestFlag}, extraArgs...)
- cmd := exec.Command(testBin, args...)
- out, err := cmd.Output()
- if err != nil {
- exitErr, ok := err.(*exec.ExitError)
- if !ok {
- return nil, fmt.Errorf("could not enumerate gtest tests: %v", err)
- }
- return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr)
- }
-
- var t []TestCase
- var suite string
- for _, line := range strings.Split(string(out), "\n") {
- // Strip comments.
- line = strings.Split(line, "#")[0]
-
- // New suite?
- if !strings.HasPrefix(line, " ") {
- suite = strings.TrimSuffix(strings.TrimSpace(line), ".")
- continue
- }
-
- // Individual test.
- name := strings.TrimSpace(line)
-
- // Do we have a suite yet?
- if suite == "" {
- return nil, fmt.Errorf("test without a suite: %v", name)
- }
-
- // Add this individual test.
- t = append(t, TestCase{
- Suite: suite,
- Name: name,
- })
-
- }
-
- if len(t) == 0 {
- return nil, fmt.Errorf("no tests parsed from %v", testBin)
- }
- return t, nil
-}
diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc
index c47a05181..3c825477c 100644
--- a/test/syscalls/linux/32bit.cc
+++ b/test/syscalls/linux/32bit.cc
@@ -74,7 +74,7 @@ void ExitGroup32(const char instruction[2], int code) {
"int $3\n"
:
: [ code ] "m"(code), [ ip ] "d"(m.ptr())
- : "rax", "rbx", "rsp");
+ : "rax", "rbx");
}
constexpr int kExitCode = 42;
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index e7c82adfc..704bae17b 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -12,8 +12,12 @@ exports_files(
"socket_ip_loopback_blocking.cc",
"socket_ip_tcp_generic_loopback.cc",
"socket_ip_tcp_loopback.cc",
+ "socket_ip_tcp_loopback_blocking.cc",
+ "socket_ip_tcp_loopback_nonblock.cc",
"socket_ip_tcp_udp_generic.cc",
"socket_ip_udp_loopback.cc",
+ "socket_ip_udp_loopback_blocking.cc",
+ "socket_ip_udp_loopback_nonblock.cc",
"socket_ip_unbound.cc",
"socket_ipv4_tcp_unbound_external_networking_test.cc",
"socket_ipv4_udp_unbound_external_networking_test.cc",
@@ -128,6 +132,17 @@ cc_library(
)
cc_library(
+ name = "socket_netlink_route_util",
+ testonly = 1,
+ srcs = ["socket_netlink_route_util.cc"],
+ hdrs = ["socket_netlink_route_util.h"],
+ deps = [
+ ":socket_netlink_util",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
name = "socket_test_util",
testonly = 1,
srcs = [
@@ -3426,6 +3441,25 @@ cc_binary(
],
)
+cc_binary(
+ name = "tuntap_test",
+ testonly = 1,
+ srcs = ["tuntap.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ "//test/syscalls/linux:socket_netlink_route_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
cc_library(
name = "udp_socket_test_cases",
testonly = 1,
@@ -3636,6 +3670,23 @@ cc_binary(
)
cc_binary(
+ name = "network_namespace_test",
+ testonly = 1,
+ srcs = ["network_namespace.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:memory_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
name = "semaphore_test",
testonly = 1,
srcs = ["semaphore.cc"],
diff --git a/test/syscalls/linux/alarm.cc b/test/syscalls/linux/alarm.cc
index d89269985..940c97285 100644
--- a/test/syscalls/linux/alarm.cc
+++ b/test/syscalls/linux/alarm.cc
@@ -188,6 +188,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
index 4dd302eed..4e473268c 100644
--- a/test/syscalls/linux/dev.cc
+++ b/test/syscalls/linux/dev.cc
@@ -153,6 +153,13 @@ TEST(DevTest, TTYExists) {
EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
}
+TEST(DevTest, NetTunExists) {
+ struct stat statbuf = {};
+ ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/exec.cc b/test/syscalls/linux/exec.cc
index b5e0a512b..07bd527e6 100644
--- a/test/syscalls/linux/exec.cc
+++ b/test/syscalls/linux/exec.cc
@@ -868,6 +868,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc
index 1c3d00287..7819f4ac3 100644
--- a/test/syscalls/linux/fallocate.cc
+++ b/test/syscalls/linux/fallocate.cc
@@ -33,7 +33,7 @@ namespace testing {
namespace {
int fallocate(int fd, int mode, off_t offset, off_t len) {
- return syscall(__NR_fallocate, fd, mode, offset, len);
+ return RetryEINTR(syscall)(__NR_fallocate, fd, mode, offset, len);
}
class AllocateTest : public FileTest {
@@ -47,27 +47,27 @@ TEST_F(AllocateTest, Fallocate) {
EXPECT_EQ(buf.st_size, 0);
// Grow to ten bytes.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 10), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 10);
// Allocate to a smaller size should be noop.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 5), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 10);
// Grow again.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 0, 20), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 20);
// Grow with offset.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 10, 20), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 30);
// Grow with offset beyond EOF.
- EXPECT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds());
+ ASSERT_THAT(fallocate(test_file_fd_.get(), 0, 39, 1), SyscallSucceeds());
ASSERT_THAT(fstat(test_file_fd_.get(), &buf), SyscallSucceeds());
EXPECT_EQ(buf.st_size, 40);
}
diff --git a/test/syscalls/linux/fcntl.cc b/test/syscalls/linux/fcntl.cc
index 421c15b87..c7cc5816e 100644
--- a/test/syscalls/linux/fcntl.cc
+++ b/test/syscalls/linux/fcntl.cc
@@ -1128,5 +1128,5 @@ int main(int argc, char** argv) {
exit(err);
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/ip_socket_test_util.h b/test/syscalls/linux/ip_socket_test_util.h
index 083ebbcf0..39fd6709d 100644
--- a/test/syscalls/linux/ip_socket_test_util.h
+++ b/test/syscalls/linux/ip_socket_test_util.h
@@ -84,20 +84,20 @@ SocketPairKind DualStackUDPBidirectionalBindSocketPair(int type);
// SocketPairs created with AF_INET and the given type.
SocketPairKind IPv4UDPUnboundSocketPair(int type);
-// IPv4UDPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET, SOCK_DGRAM, and the given type.
+// IPv4UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_DGRAM, and the given type.
SocketKind IPv4UDPUnboundSocket(int type);
-// IPv6UDPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET6, SOCK_DGRAM, and the given type.
+// IPv6UDPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_DGRAM, and the given type.
SocketKind IPv6UDPUnboundSocket(int type);
-// IPv4TCPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET, SOCK_STREAM and the given type.
+// IPv4TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET, SOCK_STREAM and the given type.
SocketKind IPv4TCPUnboundSocket(int type);
-// IPv6TCPUnboundSocketPair returns a SocketKind that represents
-// a SimpleSocket created with AF_INET6, SOCK_STREAM and the given type.
+// IPv6TCPUnboundSocket returns a SocketKind that represents a SimpleSocket
+// created with AF_INET6, SOCK_STREAM and the given type.
SocketKind IPv6TCPUnboundSocket(int type);
// IfAddrHelper is a helper class that determines the local interfaces present
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index b77e4cbd1..8b48f0804 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -349,6 +349,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
new file mode 100644
index 000000000..6ea48c263
--- /dev/null
+++ b/test/syscalls/linux/network_namespace.cc
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <net/if.h>
+#include <sched.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using TestFunc = std::function<PosixError()>;
+using RunFunc = std::function<PosixError(TestFunc)>;
+
+struct NamespaceStrategy {
+ RunFunc run;
+
+ static NamespaceStrategy Of(RunFunc run) {
+ NamespaceStrategy s;
+ s.run = run;
+ return s;
+ }
+};
+
+PosixError RunWithUnshare(TestFunc fn) {
+ PosixError err = PosixError(-1, "function did not return a value");
+ ScopedThread t([&] {
+ if (unshare(CLONE_NEWNET) != 0) {
+ err = PosixError(errno);
+ return;
+ }
+ err = fn();
+ });
+ t.Join();
+ return err;
+}
+
+PosixError RunWithClone(TestFunc fn) {
+ struct Args {
+ absl::Notification n;
+ TestFunc fn;
+ PosixError err;
+ };
+ Args args;
+ args.fn = fn;
+ args.err = PosixError(-1, "function did not return a value");
+
+ ASSIGN_OR_RETURN_ERRNO(
+ Mapping child_stack,
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ pid_t child = clone(
+ +[](void *arg) {
+ Args *args = reinterpret_cast<Args *>(arg);
+ args->err = args->fn();
+ args->n.Notify();
+ syscall(SYS_exit, 0); // Exit manually. No return address on stack.
+ return 0;
+ },
+ reinterpret_cast<void *>(child_stack.addr() + kPageSize),
+ CLONE_NEWNET | CLONE_THREAD | CLONE_SIGHAND | CLONE_VM, &args);
+ if (child < 0) {
+ return PosixError(errno, "clone() failed");
+ }
+ args.n.WaitForNotification();
+ return args.err;
+}
+
+class NetworkNamespaceTest
+ : public ::testing::TestWithParam<NamespaceStrategy> {};
+
+TEST_P(NetworkNamespaceTest, LoopbackExists) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ EXPECT_NO_ERRNO(GetParam().run([]() {
+ // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
+ // Check loopback device exists.
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ if (sock < 0) {
+ return PosixError(errno, "socket() failed");
+ }
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ if (ioctl(sock, SIOCGIFINDEX, &ifr) < 0) {
+ return PosixError(errno, "ioctl() failed, lo cannot be found");
+ }
+ return NoError();
+ }));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllNetworkNamespaceTest, NetworkNamespaceTest,
+ ::testing::Values(NamespaceStrategy::Of(RunWithUnshare),
+ NamespaceStrategy::Of(RunWithClone)));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index d07571a5f..04c5161f5 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -226,5 +226,5 @@ int main(int argc, char** argv) {
prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0));
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/prctl_setuid.cc b/test/syscalls/linux/prctl_setuid.cc
index 30f0d75b3..c4e9cf528 100644
--- a/test/syscalls/linux/prctl_setuid.cc
+++ b/test/syscalls/linux/prctl_setuid.cc
@@ -264,5 +264,5 @@ int main(int argc, char** argv) {
prctl(PR_GET_KEEPCAPS, 0, 0, 0, 0);
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index a23fdb58d..f91187e75 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -2076,5 +2076,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/ptrace.cc b/test/syscalls/linux/ptrace.cc
index 4dd5cf27b..bfe3e2603 100644
--- a/test/syscalls/linux/ptrace.cc
+++ b/test/syscalls/linux/ptrace.cc
@@ -1208,5 +1208,5 @@ int main(int argc, char** argv) {
gvisor::testing::RunExecveChild();
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h
index e3ff0579a..ca1d67691 100644
--- a/test/syscalls/linux/rseq/uapi.h
+++ b/test/syscalls/linux/rseq/uapi.h
@@ -15,14 +15,9 @@
#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
-// User-kernel ABI for restartable sequences.
+#include <stdint.h>
-// Standard types.
-//
-// N.B. This header will be included in targets that do have the standard
-// library, so we can't shadow the standard type names.
-using __u32 = __UINT32_TYPE__;
-using __u64 = __UINT64_TYPE__;
+// User-kernel ABI for restartable sequences.
#ifdef __x86_64__
// Syscall numbers.
@@ -32,20 +27,20 @@ constexpr int kRseqSyscall = 334;
#endif // __x86_64__
struct rseq_cs {
- __u32 version;
- __u32 flags;
- __u64 start_ip;
- __u64 post_commit_offset;
- __u64 abort_ip;
-} __attribute__((aligned(4 * sizeof(__u64))));
+ uint32_t version;
+ uint32_t flags;
+ uint64_t start_ip;
+ uint64_t post_commit_offset;
+ uint64_t abort_ip;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
// N.B. alignment is enforced by the kernel.
struct rseq {
- __u32 cpu_id_start;
- __u32 cpu_id;
+ uint32_t cpu_id_start;
+ uint32_t cpu_id;
struct rseq_cs* rseq_cs;
- __u32 flags;
-} __attribute__((aligned(4 * sizeof(__u64))));
+ uint32_t flags;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
constexpr int kRseqFlagUnregister = 1 << 0;
diff --git a/test/syscalls/linux/rtsignal.cc b/test/syscalls/linux/rtsignal.cc
index 81d193ffd..ed27e2566 100644
--- a/test/syscalls/linux/rtsignal.cc
+++ b/test/syscalls/linux/rtsignal.cc
@@ -167,6 +167,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/seccomp.cc b/test/syscalls/linux/seccomp.cc
index 2c947feb7..cf6499f8b 100644
--- a/test/syscalls/linux/seccomp.cc
+++ b/test/syscalls/linux/seccomp.cc
@@ -411,5 +411,5 @@ int main(int argc, char** argv) {
}
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/sigiret.cc b/test/syscalls/linux/sigiret.cc
index 4deb1ae95..6227774a4 100644
--- a/test/syscalls/linux/sigiret.cc
+++ b/test/syscalls/linux/sigiret.cc
@@ -132,6 +132,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/signalfd.cc b/test/syscalls/linux/signalfd.cc
index 95be4b66c..389e5fca2 100644
--- a/test/syscalls/linux/signalfd.cc
+++ b/test/syscalls/linux/signalfd.cc
@@ -369,5 +369,5 @@ int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/sigstop.cc b/test/syscalls/linux/sigstop.cc
index 7db57d968..b2fcedd62 100644
--- a/test/syscalls/linux/sigstop.cc
+++ b/test/syscalls/linux/sigstop.cc
@@ -147,5 +147,5 @@ int main(int argc, char** argv) {
return 1;
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/sigtimedwait.cc b/test/syscalls/linux/sigtimedwait.cc
index 1e5bf5942..4f8afff15 100644
--- a/test/syscalls/linux/sigtimedwait.cc
+++ b/test/syscalls/linux/sigtimedwait.cc
@@ -319,6 +319,5 @@ int main(int argc, char** argv) {
TEST_PCHECK(sigprocmask(SIG_BLOCK, &set, nullptr) == 0);
gvisor::testing::TestInit(&argc, &argv);
-
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index db5663ecd..1c533fdf2 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -14,6 +14,7 @@
#include "test/syscalls/linux/socket_ip_udp_generic.h"
+#include <errno.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <poll.h>
@@ -209,46 +210,6 @@ TEST_P(UDPSocketPairTest, SetMulticastLoopChar) {
EXPECT_EQ(get, kSockOptOn);
}
-// Ensure that Receiving TOS is off by default.
-TEST_P(UDPSocketPairTest, RecvTosDefault) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
-
- int get = -1;
- socklen_t get_len = sizeof(get);
- ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
- SyscallSucceedsWithValue(0));
- EXPECT_EQ(get_len, sizeof(get));
- EXPECT_EQ(get, kSockOptOff);
-}
-
-// Test that setting and getting IP_RECVTOS works as expected.
-TEST_P(UDPSocketPairTest, SetRecvTos) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
-
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
- &kSockOptOff, sizeof(kSockOptOff)),
- SyscallSucceeds());
-
- int get = -1;
- socklen_t get_len = sizeof(get);
- ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
- SyscallSucceedsWithValue(0));
- EXPECT_EQ(get_len, sizeof(get));
- EXPECT_EQ(get, kSockOptOff);
-
- ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS,
- &kSockOptOn, sizeof(kSockOptOn)),
- SyscallSucceeds());
-
- ASSERT_THAT(
- getsockopt(sockets->first_fd(), IPPROTO_IP, IP_RECVTOS, &get, &get_len),
- SyscallSucceedsWithValue(0));
- EXPECT_EQ(get_len, sizeof(get));
- EXPECT_EQ(get, kSockOptOn);
-}
-
TEST_P(UDPSocketPairTest, ReuseAddrDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -401,5 +362,97 @@ TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) {
EXPECT_EQ(get_len, sizeof(get));
}
+// Holds TOS or TClass information for IPv4 or IPv6 respectively.
+struct RecvTosOption {
+ int level;
+ int option;
+};
+
+RecvTosOption GetRecvTosOption(int domain) {
+ TEST_CHECK(domain == AF_INET || domain == AF_INET6);
+ RecvTosOption opt;
+ switch (domain) {
+ case AF_INET:
+ opt.level = IPPROTO_IP;
+ opt.option = IP_RECVTOS;
+ break;
+ case AF_INET6:
+ opt.level = IPPROTO_IPV6;
+ opt.option = IPV6_RECVTCLASS;
+ break;
+ }
+ return opt;
+}
+
+// Ensure that Receiving TOS or TCLASS is off by default.
+TEST_P(UDPSocketPairTest, RecvTosDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test that setting and getting IP_RECVTOS or IPV6_RECVTCLASS works as
+// expected.
+TEST_P(UDPSocketPairTest, SetRecvTos) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(GetParam().domain);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), t.level, t.option, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOn);
+}
+
+// Test that any socket (including IPv6 only) accepts the IPv4 TOS option: this
+// mirrors behavior in linux.
+TEST_P(UDPSocketPairTest, TOSRecvMismatch) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ RecvTosOption t = GetRecvTosOption(AF_INET);
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), t.level, t.option, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+}
+
+// Test that an IPv4 socket does not support the IPv6 TClass option.
+TEST_P(UDPSocketPairTest, TClassRecvMismatch) {
+ // This should only test AF_INET sockets for the mismatch behavior.
+ SKIP_IF(GetParam().domain != AF_INET);
+
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_IPV6, IPV6_RECVTCLASS,
+ &get, &get_len),
+ SyscallFailsWithErrno(EOPNOTSUPP));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
new file mode 100644
index 000000000..53eb3b6b2
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+
+#include <linux/if.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr uint32_t kSeq = 12345;
+
+} // namespace
+
+PosixError DumpLinks(
+ const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = seq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
+}
+
+PosixErrorOr<std::vector<Link>> DumpLinks() {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ std::vector<Link> links;
+ RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ if (hdr->nlmsg_type != RTM_NEWLINK ||
+ hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
+ return;
+ }
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ if (rta == nullptr) {
+ // Ignore links that do not have a name.
+ return;
+ }
+
+ links.emplace_back();
+ links.back().index = msg->ifi_index;
+ links.back().type = msg->ifi_type;
+ links.back().name =
+ std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ }));
+ return links;
+}
+
+PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.type == ARPHRD_LOOPBACK) {
+ return absl::optional<Link>(link);
+ }
+ }
+ return absl::optional<Link>();
+}
+
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifaddr;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr));
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifaddr.ifa_index = index;
+ req.ifaddr.ifa_family = family;
+ req.ifaddr.ifa_prefixlen = prefixlen;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFA_LOCAL;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char pad[NLMSG_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+ req.ifinfo.ifi_flags = flags;
+ req.ifinfo.ifi_change = change;
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFLA_ADDRESS;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
new file mode 100644
index 000000000..2c018e487
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+
+struct Link {
+ int index;
+ int16_t type;
+ std::string name;
+};
+
+PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn);
+
+PosixErrorOr<std::vector<Link>> DumpLinks();
+
+PosixErrorOr<absl::optional<Link>> FindLoopbackLink();
+
+// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
+// LinkChangeFlags changes interface flags. E.g. IFF_UP.
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change);
+
+// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface.
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
diff --git a/test/syscalls/linux/timers.cc b/test/syscalls/linux/timers.cc
index 2f92c27da..4b3c44527 100644
--- a/test/syscalls/linux/timers.cc
+++ b/test/syscalls/linux/timers.cc
@@ -658,5 +658,5 @@ int main(int argc, char** argv) {
}
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
new file mode 100644
index 000000000..f6ac9d7b8
--- /dev/null
+++ b/test/syscalls/linux/tuntap.cc
@@ -0,0 +1,346 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_ether.h>
+#include <linux/if_tun.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr int kIPLen = 4;
+
+constexpr const char kDevNetTun[] = "/dev/net/tun";
+constexpr const char kTapName[] = "tap0";
+
+constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
+constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB};
+
+PosixErrorOr<std::set<std::string>> DumpLinkNames() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ std::set<std::string> names;
+ for (const auto& link : links) {
+ names.emplace(link.name);
+ }
+ return names;
+}
+
+PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.name == name) {
+ return absl::optional<Link>(link);
+ }
+ }
+ return absl::optional<Link>();
+}
+
+struct pihdr {
+ uint16_t pi_flags;
+ uint16_t pi_protocol;
+} __attribute__((packed));
+
+struct ping_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct iphdr ip;
+ struct icmphdr icmp;
+ char payload[64];
+} __attribute__((packed));
+
+ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ ping_pkt pkt = {};
+
+ pkt.pi.pi_protocol = htons(ETH_P_IP);
+
+ memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest));
+ memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source));
+ pkt.eth.h_proto = htons(ETH_P_IP);
+
+ pkt.ip.ihl = 5;
+ pkt.ip.version = 4;
+ pkt.ip.tos = 0;
+ pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) +
+ sizeof(pkt.payload));
+ pkt.ip.id = 1;
+ pkt.ip.frag_off = 1 << 6; // Do not fragment
+ pkt.ip.ttl = 64;
+ pkt.ip.protocol = IPPROTO_ICMP;
+ inet_pton(AF_INET, dstip, &pkt.ip.daddr);
+ inet_pton(AF_INET, srcip, &pkt.ip.saddr);
+ pkt.ip.check = IPChecksum(pkt.ip);
+
+ pkt.icmp.type = ICMP_ECHO;
+ pkt.icmp.code = 0;
+ pkt.icmp.checksum = 0;
+ pkt.icmp.un.echo.sequence = 1;
+ pkt.icmp.un.echo.id = 1;
+
+ strncpy(pkt.payload, "abcd", sizeof(pkt.payload));
+ pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload));
+
+ return pkt;
+}
+
+struct arp_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct arphdr arp;
+ uint8_t arp_sha[ETH_ALEN];
+ uint8_t arp_spa[kIPLen];
+ uint8_t arp_tha[ETH_ALEN];
+ uint8_t arp_tpa[kIPLen];
+} __attribute__((packed));
+
+std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ std::string buffer;
+ buffer.resize(sizeof(arp_pkt));
+
+ arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]);
+ {
+ pkt->pi.pi_protocol = htons(ETH_P_ARP);
+
+ memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest));
+ memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source));
+ pkt->eth.h_proto = htons(ETH_P_ARP);
+
+ pkt->arp.ar_hrd = htons(ARPHRD_ETHER);
+ pkt->arp.ar_pro = htons(ETH_P_IP);
+ pkt->arp.ar_hln = ETH_ALEN;
+ pkt->arp.ar_pln = kIPLen;
+ pkt->arp.ar_op = htons(ARPOP_REPLY);
+
+ memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha));
+ inet_pton(AF_INET, srcip, pkt->arp_spa);
+ memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha));
+ inet_pton(AF_INET, dstip, pkt->arp_tpa);
+ }
+ return buffer;
+}
+
+} // namespace
+
+class TuntapTest : public ::testing::Test {
+ protected:
+ void TearDown() override {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) {
+ // Bring back capability if we had dropped it in test case.
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true));
+ }
+ }
+};
+
+TEST_F(TuntapTest, CreateInterfaceNoCap) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ strncpy(ifr.ifr_name, kTapName, IFNAMSIZ);
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(TuntapTest, CreateFixedNameInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_expect = ifr_set;
+ // See __tun_chr_ioctl() in net/drivers/tun.c.
+ ifr_expect.ifr_flags |= IFF_NOFILTER;
+
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(kTapName)));
+ EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0));
+}
+
+TEST_F(TuntapTest, CreateInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ // Empty ifr.ifr_name. Let kernel assign.
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ std::string ifname = ifr_get.ifr_name;
+ EXPECT_THAT(ifname, ::testing::StartsWith("tap"));
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(ifname)));
+}
+
+TEST_F(TuntapTest, InvalidReadWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ char buf[128] = {};
+ EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+}
+
+TEST_F(TuntapTest, WriteToDownDevice) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // FIXME: gVisor always creates enabled/up'd interfaces.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ // Device created should be down by default.
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ char buf[128] = {};
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO));
+}
+
+// This test sets up a TAP device and pings kernel by sending ICMP echo request.
+//
+// It works as the following:
+// * Open /dev/net/tun, and create kTapName interface.
+// * Use rtnetlink to do initial setup of the interface:
+// * Assign IP address 10.0.0.1/24 to kernel.
+// * MAC address: kMacA
+// * Bring up the interface.
+// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel.
+// * Loop to receive packets from TAP device/fd:
+// * If packet is an ICMP echo reply, it stops and passes the test.
+// * If packet is an ARP request, it responds with canned reply and resends
+// the
+// ICMP request packet.
+TEST_F(TuntapTest, PingKernel) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // Interface creation.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ absl::optional<Link> link =
+ ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName));
+ ASSERT_TRUE(link.has_value());
+
+ // Interface setup.
+ struct in_addr addr;
+ inet_pton(AF_INET, "10.0.0.1", &addr);
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24,
+ &addr, sizeof(addr)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME: gVisor doesn't support setting MAC address on interfaces yet.
+ EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA)));
+
+ // FIXME: gVisor always creates enabled/up'd interfaces.
+ EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP));
+ }
+
+ ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+ std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+
+ // Send ping, this would trigger an ARP request on Linux.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+
+ // Receive loop to process inbound packets.
+ struct inpkt {
+ union {
+ pihdr pi;
+ ping_pkt ping;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ // Process ARP packet.
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ // Respond with canned ARP reply.
+ EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()),
+ SyscallSucceedsWithValue(arp_rep.size()));
+ // First ping request might have been dropped due to mac address not in
+ // ARP cache. Send it again.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+ }
+
+ // Process ping response packet.
+ if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol &&
+ r.ping.ip.protocol == ping_req.ip.protocol &&
+ !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) &&
+ !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) &&
+ r.ping.icmp.type == 0 && r.ping.icmp.code == 0) {
+ // Ends and passes the test.
+ break;
+ }
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index 9f8de6b48..740c7986d 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -21,6 +21,10 @@
#include <sys/socket.h>
#include <sys/types.h>
+#ifndef SIOCGSTAMP
+#include <linux/sockios.h>
+#endif
+
#include "gtest/gtest.h"
#include "absl/base/macros.h"
#include "absl/time/clock.h"
@@ -1349,9 +1353,6 @@ TEST_P(UdpSocketTest, TimestampIoctlPersistence) {
// outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SetAndReceiveTOS) {
- // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack.
- SKIP_IF((GetParam() != AddressFamily::kIpv4) && IsRunningOnGvisor() &&
- !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
ASSERT_THAT(connect(t_, addr_[0], addrlen_), SyscallSucceeds());
@@ -1422,7 +1423,6 @@ TEST_P(UdpSocketTest, SetAndReceiveTOS) {
// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or
// IPV6_RECVTCLASS will create the corresponding control message.
TEST_P(UdpSocketTest, SendAndReceiveTOS) {
- // TODO(b/144868438): IPV6_RECVTCLASS not supported for netstack.
// TODO(b/146661005): Setting TOS via cmsg not supported for netstack.
SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet());
ASSERT_THAT(bind(s_, addr_[0], addrlen_), SyscallSucceeds());
diff --git a/test/syscalls/linux/vfork.cc b/test/syscalls/linux/vfork.cc
index 0aaba482d..19d05998e 100644
--- a/test/syscalls/linux/vfork.cc
+++ b/test/syscalls/linux/vfork.cc
@@ -191,5 +191,5 @@ int main(int argc, char** argv) {
return gvisor::testing::RunChild();
}
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/syscalls/syscall_test_runner.sh b/test/syscalls/syscall_test_runner.sh
deleted file mode 100755
index 864bb2de4..000000000
--- a/test/syscalls/syscall_test_runner.sh
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/bin/bash
-
-# Copyright 2018 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# syscall_test_runner.sh is a simple wrapper around the go syscall test runner.
-# It exists so that we can build the syscall test runner once, and use it for
-# all syscall tests, rather than build it for each test run.
-
-set -euf -x -o pipefail
-
-echo -- "$@"
-
-if [[ -n "${TEST_UNDECLARED_OUTPUTS_DIR}" ]]; then
- mkdir -p "${TEST_UNDECLARED_OUTPUTS_DIR}"
- chmod a+rwx "${TEST_UNDECLARED_OUTPUTS_DIR}"
-fi
-
-# Get location of syscall_test_runner binary.
-readonly runner=$(find "${TEST_SRCDIR}" -name syscall_test_runner)
-
-# Pass the arguments of this script directly to the runner.
-exec "${runner}" "$@"
diff --git a/test/util/BUILD b/test/util/BUILD
index 1f22ebe29..8b5a0f25c 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "cc_library", "cc_test", "gtest", "select_system")
+load("//tools:defs.bzl", "cc_library", "cc_test", "gbenchmark", "gtest", "select_system")
package(
default_visibility = ["//:sandbox"],
@@ -260,6 +260,7 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/time",
gtest,
+ gbenchmark,
],
)
diff --git a/test/util/test_main.cc b/test/util/test_main.cc
index 5c7ee0064..1f389e58f 100644
--- a/test/util/test_main.cc
+++ b/test/util/test_main.cc
@@ -16,5 +16,5 @@
int main(int argc, char** argv) {
gvisor::testing::TestInit(&argc, &argv);
- return RUN_ALL_TESTS();
+ return gvisor::testing::RunAllTests();
}
diff --git a/test/util/test_util.h b/test/util/test_util.h
index 2d22b0eb8..c5cb9d6d6 100644
--- a/test/util/test_util.h
+++ b/test/util/test_util.h
@@ -771,6 +771,7 @@ std::string RunfilePath(std::string path);
#endif
void TestInit(int* argc, char*** argv);
+int RunAllTests(void);
} // namespace testing
} // namespace gvisor
diff --git a/test/util/test_util_impl.cc b/test/util/test_util_impl.cc
index ba7c0a85b..7e1ad9e66 100644
--- a/test/util/test_util_impl.cc
+++ b/test/util/test_util_impl.cc
@@ -17,8 +17,12 @@
#include "gtest/gtest.h"
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
+#include "benchmark/benchmark.h"
#include "test/util/logging.h"
+extern bool FLAGS_benchmark_list_tests;
+extern std::string FLAGS_benchmark_filter;
+
namespace gvisor {
namespace testing {
@@ -26,6 +30,7 @@ void SetupGvisorDeathTest() {}
void TestInit(int* argc, char*** argv) {
::testing::InitGoogleTest(argc, *argv);
+ benchmark::Initialize(argc, *argv);
::absl::ParseCommandLine(*argc, *argv);
// Always mask SIGPIPE as it's common and tests aren't expected to handle it.
@@ -34,5 +39,14 @@ void TestInit(int* argc, char*** argv) {
TEST_CHECK(sigaction(SIGPIPE, &sa, nullptr) == 0);
}
+int RunAllTests() {
+ if (FLAGS_benchmark_list_tests || FLAGS_benchmark_filter != ".") {
+ benchmark::RunSpecifiedBenchmarks();
+ return 0;
+ } else {
+ return RUN_ALL_TESTS();
+ }
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index 6798362dc..905b16d41 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -8,7 +8,6 @@ load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image")
load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image")
load("@pydeps//:requirements.bzl", _py_requirement = "requirement")
-load("//tools/bazeldefs:tags.bzl", _go_suffixes = "go_suffixes")
container_image = _container_image
cc_binary = _cc_binary
@@ -19,8 +18,8 @@ cc_test = _cc_test
cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
go_image = _go_image
go_embed_data = _go_embed_data
-go_suffixes = _go_suffixes
gtest = "@com_google_googletest//:gtest"
+gbenchmark = "@com_google_benchmark//:benchmark"
loopback = "//tools/bazeldefs:loopback"
proto_library = native.proto_library
pkg_deb = _pkg_deb
diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl
new file mode 100644
index 000000000..92b0b5fc0
--- /dev/null
+++ b/tools/bazeldefs/platforms.bzl
@@ -0,0 +1,17 @@
+"""List of platforms."""
+
+# Platform to associated tags.
+platforms = {
+ "ptrace": [
+ # TODO(b/120560048): Make the tests run without this tag.
+ "no-sandbox",
+ ],
+ "kvm": [
+ "manual",
+ "local",
+ # TODO(b/120560048): Make the tests run without this tag.
+ "no-sandbox",
+ ],
+}
+
+default_platform = "ptrace"
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 39f035f12..15a310403 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,7 +7,9 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
-load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
+load("//tools/bazeldefs:tags.bzl", "go_suffixes")
# Delegate directly.
cc_binary = _cc_binary
@@ -21,6 +23,7 @@ go_image = _go_image
go_test = _go_test
go_tool_library = _go_tool_library
gtest = _gtest
+gbenchmark = _gbenchmark
pkg_deb = _pkg_deb
pkg_tar = _pkg_tar
py_library = _py_library
@@ -32,6 +35,8 @@ select_system = _select_system
loopback = _loopback
default_installer = _default_installer
default_net_util = _default_net_util
+platforms = _platforms
+default_platform = _default_platform
def go_binary(name, **kwargs):
"""Wraps the standard go_binary.
@@ -83,7 +88,7 @@ def go_imports(name, src, out):
cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"),
)
-def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs):
+def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs):
"""Wraps the standard go_library and does stateification and marshalling.
The recommended way is to use this rule with mostly identical configuration as the native
@@ -106,6 +111,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
imports: imports required for stateify.
stateify: whether statify is enabled (default: true).
marshal: whether marshal is enabled (default: false).
+ marshal_debug: whether the gomarshal tools emits debugging output (default: false).
**kwargs: standard go_library arguments.
"""
all_srcs = srcs
@@ -144,7 +150,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
go_marshal(
name = name + suffix + "_abi_autogen",
srcs = src_subset,
- debug = False,
+ debug = select({
+ "//tools/go_marshal:marshal_config_verbose": True,
+ "//conditions:default": marshal_debug,
+ }),
imports = imports,
package = name,
)
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
index 80d9c0504..be49cf9c8 100644
--- a/tools/go_marshal/BUILD
+++ b/tools/go_marshal/BUILD
@@ -12,3 +12,8 @@ go_binary(
"//tools/go_marshal/gomarshal",
],
)
+
+config_setting(
+ name = "marshal_config_verbose",
+ values = {"define": "gomarshal=verbose"},
+)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 0294ba5ba..d365a1f3c 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -44,7 +44,8 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "len", "ptr", "src", "srcs", "task", "val",
+ "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len",
+ "ptr", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
@@ -193,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
return files, fsets, nil
}
-// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
var types []*ast.TypeSpec
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
@@ -222,14 +223,22 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
continue
}
for _, spec := range gdecl.Specs {
- // We already confirmed we're in a type declaration earlier.
+ // We already confirmed we're in a type declaration earlier, so this
+ // cast will succeed.
t := spec.(*ast.TypeSpec)
- if _, ok := t.Type.(*ast.StructType); ok {
- debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
+ switch t.Type.(type) {
+ case *ast.StructType:
+ debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
+ case *ast.Ident: // Newtype on primitive.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
types = append(types, t)
continue
}
- debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
}
return types
@@ -269,12 +278,20 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- // We're guaranteed to have only struct type specs by now. See
- // Generator.collectMarshallabeTypes.
i := newInterfaceGenerator(t, fset)
- i.validate()
- i.emitMarshallable()
- return i
+ switch ty := t.Type.(type) {
+ case *ast.StructType:
+ i.validateStruct()
+ i.emitMarshallableForStruct()
+ return i
+ case *ast.Ident:
+ i.validatePrimitiveNewtype(ty)
+ i.emitMarshallableForPrimitiveNewtype()
+ return i
+ default:
+ // This should've been filtered out by collectMarshallabeTypes.
+ panic(fmt.Sprintf("Unexpected type %+v", ty))
+ }
}
// generateOneTestSuite generates a test suite for the automatically generated
@@ -320,7 +337,7 @@ func (g *Generator) Run() error {
for i, a := range asts {
// Collect type declarations marked for code generation and generate
// Marshallable interfaces.
- for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
impl := g.generateOne(t, fsets[i])
// Collect Marshallable types referenced by the generated code.
for ref, _ := range impl.ms {
@@ -338,17 +355,6 @@ func (g *Generator) Run() error {
}
}
- // Tool was invoked with input files with no data structures marked for code
- // generation. This is probably not what the user intended.
- if len(impls) == 0 {
- var buf bytes.Buffer
- fmt.Fprintf(&buf, "go_marshal invoked on these files, but they don't contain any types requiring code generation. Perhaps mark some with \"// +marshal\"?:\n")
- for _, i := range g.inputs {
- fmt.Fprintf(&buf, " %s\n", i)
- }
- abort(buf.String())
- }
-
// Write output file header. These include things like package name and
// import statements.
if err := g.writeHeader(); err != nil {
@@ -391,6 +397,26 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
}
// Write test functions.
+
+ // If we didn't generate any Marshallable implementations, we can't just
+ // emit an empty test file, since that causes the build to fail with "no
+ // tests/benchmarks/examples found". Unfortunately we can't signal bazel to
+ // omit the entire package since the outputs are already defined before
+ // go-marshal is called. If we'd otherwise emit an empty test suite, emit an
+ // empty example instead.
+ if len(ts) == 0 {
+ b.reset()
+ b.emit("func ExampleEmptyTestSuite() {\n")
+ b.inIndent(func() {
+ b.emit("// This example is intentionally empty to ensure this file contains at least\n")
+ b.emit("// one testable entity. go-marshal is forced to emit a test file if a package\n")
+ b.emit("// is marked marshallable, but emitting a test file with no entities results\n")
+ b.emit("// in a build failure.\n")
+ })
+ b.emit("}\n")
+ return b.write(g.outputTest)
+ }
+
for _, t := range ts {
if err := t.write(g.outputTest); err != nil {
return err
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 3aa299ccd..ea1af998e 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string {
// newinterfaceGenerator creates a new interface generator.
func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &interfaceGenerator{
t: t,
r: receiverName(t),
@@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-// validate ensures the type we're working with can be marshalled. These checks
-// are done ahead of time and in one place so we can make assumptions later.
-func (g *interfaceGenerator) validate() {
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct() {
g.forEachField(func(f *ast.Field) {
if len(f.Names) == 0 {
g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
@@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() {
g.forEachField(func(f *ast.Field) {
fieldDispatcher{
primitive: func(_, t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we
- // provide suggestions for some disallowed types and reject
- // them, then attempt to marshal any remaining types by
- // invoking the marshal.Marshallable interface on them. If
- // these types don't actually implement
- // marshal.Marshallable, compilation of the generated code
- // will fail with an appropriate error message.
- return
- case "int":
- g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
+ g.validatePrimitiveNewtype(t)
},
selector: func(_, _, _ *ast.Ident) {
// No validation to perform on selector fields. However this
@@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
+func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string)
}
}
-func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
+// byte slice.
+func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
- case "int8":
- g.emit("%s = int8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
- case "uint8":
- g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
g.shift(bufVar, 1)
-
- case "int16":
- g.recordUsedImport("usermem")
- g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
- g.shift(bufVar, 2)
- case "uint16":
+ case "int8", "uint8":
+ g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
g.shift(bufVar, 2)
-
- case "int32":
- g.recordUsedImport("usermem")
- g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
- g.shift(bufVar, 4)
- case "uint32":
+ case "int32", "uint32":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
g.shift(bufVar, 4)
-
- case "int64":
- g.recordUsedImport("usermem")
- g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
- g.shift(bufVar, 8)
- case "uint64":
+ case "int64", "uint64":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
g.shift(bufVar, 8)
default:
g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
@@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string
}
}
+// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
// packed. Returns "", false if g.t has no fields that may be potentially
// packed, otherwise returns <clause>, true, where <clause> is an expression
@@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
return strings.Join(cs, " && "), true
}
-func (g *interfaceGenerator) emitMarshallable() {
+func (g *interfaceGenerator) emitMarshallableForStruct() {
// Is g.t a packed struct without consideing field types?
thisPacked := true
g.forEachField(func(f *ast.Field) {
@@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
})
g.emit("}\n")
},
@@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
})
g.emit("}\n")
},
@@ -507,13 +536,14 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("return task.CopyOutBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("return err\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -539,11 +569,11 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
- g.emit("len, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
g.emit("// must live until after the CopyOutBytes.\n")
g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return len, err\n")
+ g.emit("return err\n")
} else {
fallback()
}
@@ -553,20 +583,20 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("n, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
g.emit("if err != nil {\n")
g.inIndent(func() {
- g.emit("return n, err\n")
+ g.emit("return err\n")
})
g.emit("}\n")
g.emit("%s.UnmarshalBytes(buf)\n", g.r)
- g.emit("return n, nil\n")
+ g.emit("return nil\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -592,11 +622,11 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
- g.emit("len, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
g.emit("// must live until after the CopyInBytes.\n")
g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return len, err\n")
+ g.emit("return err\n")
} else {
fallback()
}
@@ -649,3 +679,144 @@ func (g *interfaceGenerator) emitMarshallable() {
})
g.emit("}\n\n")
}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ nt := g.t.Type.(*ast.Ident)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 8c28b00d0..8ba47eb67 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -49,9 +49,6 @@ type testGenerator struct {
}
func newTestGenerator(t *ast.TypeSpec) *testGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &testGenerator{
t: t,
r: receiverName(t),
@@ -69,14 +66,6 @@ func (g *testGenerator) typeName() string {
return g.t.Name.Name
}
-func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
func (g *testGenerator) testFuncName(base string) string {
return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
}
@@ -89,10 +78,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) {
func (g *testGenerator) emitTestNonZeroSize() {
g.inTestFunction("TestSizeNonZero", func() {
- g.emit("x := &%s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("if x.SizeBytes() == 0 {\n")
g.inIndent(func() {
- g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
})
g.emit("}\n")
})
@@ -100,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() {
func (g *testGenerator) emitTestSuspectAlignment() {
g.inTestFunction("TestSuspectAlignment", func() {
- g.emit("x := %s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
})
}
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
index 20353850d..f129788e0 100644
--- a/tools/go_marshal/marshal/marshal.go
+++ b/tools/go_marshal/marshal/marshal.go
@@ -91,12 +91,12 @@ type Marshallable interface {
// marshalled does not escape. The implementation should avoid creating
// extra copies in memory by directly deserializing to the object's
// underlying memory.
- CopyIn(task Task, addr usermem.Addr) (int, error)
+ CopyIn(task Task, addr usermem.Addr) error
// CopyOut serializes a Marshallable type to a task's memory. This may only
// be called from a task goroutine. This is more efficient than calling
// MarshalUnsafe on Marshallable.Packed types, as the type being serialized
// does not escape. The implementation should avoid creating extra copies in
// memory by directly serializing from the object's underlying memory.
- CopyOut(task Task, addr usermem.Addr) (int, error)
+ CopyOut(task Task, addr usermem.Addr) error
}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 8de02d707..93229dedb 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -103,3 +103,13 @@ type Stat struct {
CTime Timespec
_ [3]int64
}
+
+// SignalSet is an example marshallable newtype on a primitive.
+//
+// +marshal
+type SignalSet uint64
+
+// SignalSetAlias is an example newtype on another marshallable type.
+//
+// +marshal
+type SignalSetAlias SignalSet
diff --git a/tools/installers/master.sh b/tools/installers/master.sh
index 7b1956454..52f9734a6 100755
--- a/tools/installers/master.sh
+++ b/tools/installers/master.sh
@@ -15,6 +15,21 @@
# limitations under the License.
# Install runsc from the master branch.
+set -e
+
curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
-apt-get update && apt-get install -y runsc
+while true; do
+ if apt-get update; then
+ apt-get install -y runsc
+ break
+ fi
+ result=$?
+ # Check if apt update failed to aquire the file lock.
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+runsc install
+service docker restart
+