summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.travis.yml19
-rw-r--r--Dockerfile11
-rw-r--r--Makefile5
-rw-r--r--WORKSPACE21
-rw-r--r--benchmarks/harness/machine.py9
-rw-r--r--benchmarks/harness/ssh_connection.py9
-rw-r--r--kokoro/runtime_tests/go1.12.cfg10
-rw-r--r--kokoro/runtime_tests/java11.cfg10
-rw-r--r--kokoro/runtime_tests/nodejs12.4.0.cfg10
-rw-r--r--kokoro/runtime_tests/php7.3.6.cfg10
-rw-r--r--kokoro/runtime_tests/python3.7.3.cfg10
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/epoll.go7
-rw-r--r--pkg/abi/linux/epoll_amd64.go2
-rw-r--r--pkg/abi/linux/epoll_arm64.go2
-rw-r--r--pkg/abi/linux/file.go2
-rw-r--r--pkg/abi/linux/fs.go2
-rw-r--r--pkg/abi/linux/ioctl.go26
-rw-r--r--pkg/abi/linux/ioctl_tun.go (renamed from pkg/fspath/builder_unsafe.go)24
-rw-r--r--pkg/abi/linux/signal.go2
-rw-r--r--pkg/abi/linux/time.go6
-rw-r--r--pkg/abi/linux/xattr.go1
-rw-r--r--pkg/atomicbitops/BUILD1
-rw-r--r--pkg/fspath/BUILD4
-rw-r--r--pkg/fspath/builder.go8
-rw-r--r--pkg/fspath/fspath.go3
-rw-r--r--pkg/gohacks/BUILD11
-rw-r--r--pkg/gohacks/gohacks_unsafe.go57
-rw-r--r--pkg/log/glog.go6
-rw-r--r--pkg/log/json.go2
-rw-r--r--pkg/log/json_k8s.go2
-rw-r--r--pkg/log/log.go60
-rw-r--r--pkg/sentry/arch/arch_aarch64.go3
-rw-r--r--pkg/sentry/fs/dev/BUILD5
-rw-r--r--pkg/sentry/fs/dev/dev.go10
-rw-r--r--pkg/sentry/fs/dev/net_tun.go170
-rw-r--r--pkg/sentry/fs/dirent.go25
-rw-r--r--pkg/sentry/fs/mount_test.go11
-rw-r--r--pkg/sentry/fs/mounts.go10
-rw-r--r--pkg/sentry/fs/proc/mounts.go16
-rw-r--r--pkg/sentry/fs/proc/net.go5
-rw-r--r--pkg/sentry/fs/proc/sys_net.go4
-rw-r--r--pkg/sentry/fsbridge/vfs.go10
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go4
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_net.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go4
-rw-r--r--pkg/sentry/fsimpl/testutil/kernel.go1
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go8
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go233
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go16
-rw-r--r--pkg/sentry/inet/BUILD1
-rw-r--r--pkg/sentry/inet/namespace.go99
-rw-r--r--pkg/sentry/kernel/fd_table.go49
-rw-r--r--pkg/sentry/kernel/fs_context.go22
-rw-r--r--pkg/sentry/kernel/kernel.go30
-rw-r--r--pkg/sentry/kernel/rseq.go2
-rw-r--r--pkg/sentry/kernel/task.go27
-rw-r--r--pkg/sentry/kernel/task_clone.go16
-rw-r--r--pkg/sentry/kernel/task_context.go2
-rw-r--r--pkg/sentry/kernel/task_exec.go2
-rw-r--r--pkg/sentry/kernel/task_log.go6
-rw-r--r--pkg/sentry/kernel/task_net.go19
-rw-r--r--pkg/sentry/kernel/task_start.go8
-rw-r--r--pkg/sentry/kernel/task_usermem.go2
-rw-r--r--pkg/sentry/mm/address_space.go23
-rw-r--r--pkg/sentry/mm/lifecycle.go26
-rw-r--r--pkg/sentry/mm/mm.go5
-rw-r--r--pkg/sentry/mm/mm_test.go2
-rw-r--r--pkg/sentry/platform/kvm/machine.go18
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s73
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go8
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/epoll.go89
-rw-r--r--pkg/sentry/strace/linux64_amd64.go6
-rw-r--r--pkg/sentry/strace/linux64_arm64.go4
-rw-r--r--pkg/sentry/strace/socket.go9
-rw-r--r--pkg/sentry/strace/strace.go8
-rw-r--r--pkg/sentry/strace/syscalls.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_epoll.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go44
-rw-r--r--pkg/sentry/syscalls/linux/sys_getdents.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_lseek.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_mmap.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_read.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_stat.go10
-rw-r--r--pkg/sentry/syscalls/linux/sys_sync.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_write.go4
-rw-r--r--pkg/sentry/syscalls/linux/sys_xattr.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/BUILD28
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll.go225
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go44
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/execve.go137
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go147
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/filesystem.go326
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fscontext.go131
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/getdents.go149
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go (renamed from pkg/usermem/usermem_unsafe.go)26
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go216
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/mmap.go92
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/path.go94
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/poll.go584
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/read_write.go511
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/setstat.go380
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/stat.go346
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sync.go87
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/sys_read.go95
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/xattr.go353
-rw-r--r--pkg/sentry/vfs/BUILD1
-rw-r--r--pkg/sentry/vfs/epoll.go3
-rw-r--r--pkg/sentry/vfs/mount.go17
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go12
-rw-r--r--pkg/sentry/vfs/resolving_path.go2
-rw-r--r--pkg/sentry/vfs/vfs.go26
-rw-r--r--pkg/syserror/syserror.go1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go63
-rw-r--r--pkg/tcpip/buffer/view.go6
-rw-r--r--pkg/tcpip/header/ipv6.go22
-rw-r--r--pkg/tcpip/header/ipv6_test.go98
-rw-r--r--pkg/tcpip/link/channel/BUILD1
-rw-r--r--pkg/tcpip/link/channel/channel.go180
-rw-r--r--pkg/tcpip/link/tun/BUILD18
-rw-r--r--pkg/tcpip/link/tun/device.go352
-rw-r--r--pkg/tcpip/link/tun/protocol.go56
-rw-r--r--pkg/tcpip/network/arp/arp.go20
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go6
-rw-r--r--pkg/tcpip/stack/ndp.go25
-rw-r--r--pkg/tcpip/stack/ndp_test.go900
-rw-r--r--pkg/tcpip/stack/nic.go149
-rw-r--r--pkg/tcpip/stack/registration.go23
-rw-r--r--pkg/tcpip/stack/stack.go103
-rw-r--r--pkg/tcpip/stack/stack_test.go466
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go20
-rw-r--r--pkg/tcpip/stack/transport_test.go15
-rw-r--r--pkg/tcpip/tcpip.go8
-rw-r--r--pkg/tcpip/time_unsafe.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go16
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go5
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go9
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/dispatcher.go31
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go33
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/udp/protocol.go14
-rw-r--r--pkg/usermem/BUILD2
-rw-r--r--pkg/usermem/usermem.go9
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/controller.go11
-rw-r--r--runsc/boot/filter/config.go2
-rw-r--r--runsc/boot/loader.go121
-rw-r--r--runsc/boot/network.go27
-rw-r--r--runsc/boot/pprof/BUILD11
-rw-r--r--runsc/boot/pprof/pprof.go (renamed from runsc/boot/pprof.go)6
-rw-r--r--runsc/container/container_test.go1
-rw-r--r--runsc/sandbox/network.go25
-rwxr-xr-xscripts/common_build.sh12
-rw-r--r--test/perf/BUILD8
-rw-r--r--test/perf/linux/getdents_benchmark.cc2
-rw-r--r--test/runner/defs.bzl58
-rw-r--r--test/runner/gtest/gtest.go25
-rw-r--r--test/syscalls/BUILD4
-rw-r--r--test/syscalls/linux/32bit.cc2
-rw-r--r--test/syscalls/linux/BUILD47
-rw-r--r--test/syscalls/linux/dev.cc7
-rw-r--r--test/syscalls/linux/network_namespace.cc121
-rw-r--r--test/syscalls/linux/rseq/uapi.h29
-rw-r--r--test/syscalls/linux/socket_generic.cc96
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.cc163
-rw-r--r--test/syscalls/linux/socket_netlink_route_util.h55
-rw-r--r--test/syscalls/linux/tuntap.cc346
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc4
-rw-r--r--tools/bazeldefs/defs.bzl2
-rw-r--r--tools/bazeldefs/platforms.bzl17
-rw-r--r--tools/defs.bzl14
-rw-r--r--tools/go_marshal/BUILD5
-rw-r--r--tools/go_marshal/gomarshal/generator.go43
-rw-r--r--tools/go_marshal/gomarshal/generator_interfaces.go317
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go17
-rw-r--r--tools/go_marshal/marshal/marshal.go4
-rw-r--r--tools/go_marshal/test/test.go10
-rwxr-xr-xtools/installers/master.sh17
186 files changed, 8692 insertions, 1216 deletions
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 000000000..a2a260538
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,19 @@
+language: minimal
+sudo: required
+dist: xenial
+cache:
+ directories:
+ - /home/travis/.cache/bazel/
+services:
+ - docker
+matrix:
+ include:
+ - os: linux
+ arch: amd64
+ env: RUNSC_PATH=./bazel-bin/runsc/linux_amd64_pure_stripped/runsc
+ - os: linux
+ arch: arm64
+ env: RUNSC_PATH=./bazel-bin/runsc/linux_arm64_pure_stripped/runsc
+script:
+ - uname -a
+ - make DOCKER_RUN_OPTIONS="" BAZEL_OPTIONS="build runsc:runsc" bazel && $RUNSC_PATH --alsologtostderr --network none --debug --TESTONLY-unsafe-nonroot=true --rootless do ls
diff --git a/Dockerfile b/Dockerfile
index 738623023..2bfdfec6c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,8 +1,9 @@
-FROM ubuntu:bionic
+FROM fedora:31
-RUN apt-get update && apt-get install -y curl gnupg2 git python python3 python3-distutils python3-pip
-RUN echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list && \
- curl https://bazel.build/bazel-release.pub.gpg | apt-key add -
-RUN apt-get update && apt-get install -y bazel && apt-get clean
+RUN dnf install -y dnf-plugins-core && dnf copr enable -y vbatts/bazel
+
+RUN dnf install -y bazel2 git gcc make golang gcc-c++ glibc-devel python3 which python3-pip python3-devel libffi-devel openssl-devel pkg-config glibc-static
+
+RUN pip install pycparser
WORKDIR /gvisor
diff --git a/Makefile b/Makefile
index a73bc0c36..d9531fbd5 100644
--- a/Makefile
+++ b/Makefile
@@ -2,6 +2,9 @@ UID := $(shell id -u ${USER})
GID := $(shell id -g ${USER})
GVISOR_BAZEL_CACHE := $(shell readlink -f ~/.cache/bazel/)
+# The --privileged is required to run tests.
+DOCKER_RUN_OPTIONS ?= --privileged
+
all: runsc
docker-build:
@@ -19,7 +22,7 @@ bazel-server-start: docker-build
-v "$(CURDIR):$(CURDIR)" \
--workdir "$(CURDIR)" \
--tmpfs /tmp:rw,exec \
- --privileged \
+ $(DOCKER_RUN_OPTIONS) \
gvisor-bazel \
sh -c "while :; do sleep 100; done" && \
docker exec --user 0:0 -i gvisor-bazel sh -c "groupadd --gid $(GID) --non-unique gvisor && useradd --uid $(UID) --non-unique --gid $(GID) -d $(HOME) gvisor"
diff --git a/WORKSPACE b/WORKSPACE
index ff0196dc6..a15238a2e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -33,6 +33,20 @@ load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies", "go_repository")
gazelle_dependencies()
+# TODO(gvisor.dev/issue/1876): Move the statement to "External repositories"
+# block below once 1876 is fixed.
+#
+# The com_google_protobuf repository below would trigger downloading a older
+# version of org_golang_x_sys. If putting this repository statment in a place
+# after that of the com_google_protobuf, this statement will not work as
+# expectd to download a new version of org_golang_x_sys.
+go_repository(
+ name = "org_golang_x_sys",
+ importpath = "golang.org/x/sys",
+ sum = "h1:72l8qCJ1nGxMGH26QVBVIxKd/D34cfGt0OvrPtpemyY=",
+ version = "v0.0.0-20191220220014-0732a990476f",
+)
+
# Load C++ rules.
http_archive(
name = "rules_cc",
@@ -257,13 +271,6 @@ go_repository(
)
go_repository(
- name = "org_golang_x_sys",
- importpath = "golang.org/x/sys",
- sum = "h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=",
- version = "v0.0.0-20190215142949-d0b11bdaac8a",
-)
-
-go_repository(
name = "org_golang_x_time",
commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
diff --git a/benchmarks/harness/machine.py b/benchmarks/harness/machine.py
index 3d32d3dda..5bdc4aa85 100644
--- a/benchmarks/harness/machine.py
+++ b/benchmarks/harness/machine.py
@@ -43,6 +43,8 @@ from benchmarks.harness import machine_mocks
from benchmarks.harness import ssh_connection
from benchmarks.harness import tunnel_dispatcher
+log = logging.getLogger(__name__)
+
class Machine(object):
"""The machine object is the primary object for benchmarks.
@@ -236,9 +238,10 @@ class RemoteMachine(Machine):
archive=archive, dir=harness.REMOTE_INSTALLERS_PATH))
self._has_installers = True
- # Execute the remote installer.
- self.run("sudo {dir}/{file}".format(
- dir=harness.REMOTE_INSTALLERS_PATH, file=installer))
+ # Execute the remote installer.
+ self.run("sudo {dir}/{file}".format(
+ dir=harness.REMOTE_INSTALLERS_PATH, file=installer))
+
if results:
results[index] = True
diff --git a/benchmarks/harness/ssh_connection.py b/benchmarks/harness/ssh_connection.py
index a50e34293..b8c8e42d4 100644
--- a/benchmarks/harness/ssh_connection.py
+++ b/benchmarks/harness/ssh_connection.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""SSHConnection handles the details of SSH connections."""
-
+import logging
import os
import warnings
@@ -24,6 +24,8 @@ from benchmarks import harness
# Get rid of paramiko Cryptography Warnings.
warnings.filterwarnings(action="ignore", module=".*paramiko.*")
+log = logging.getLogger(__name__)
+
def send_one_file(client: paramiko.SSHClient, path: str,
remote_dir: str) -> str:
@@ -94,10 +96,13 @@ class SSHConnection:
The contents of stdout and stderr.
"""
with self._client() as client:
+ log.info("running command: %s", cmd)
_, stdout, stderr = client.exec_command(command=cmd)
- stdout.channel.recv_exit_status()
+ log.info("returned status: %d", stdout.channel.recv_exit_status())
stdout = stdout.read().decode("utf-8")
stderr = stderr.read().decode("utf-8")
+ log.info("stdout: %s", stdout)
+ log.info("stderr: %s", stderr)
return stdout, stderr
def send_workload(self, name: str) -> str:
diff --git a/kokoro/runtime_tests/go1.12.cfg b/kokoro/runtime_tests/go1.12.cfg
index 164ddc18f..fd4911e88 100644
--- a/kokoro/runtime_tests/go1.12.cfg
+++ b/kokoro/runtime_tests/go1.12.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "go1.12"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/java11.cfg b/kokoro/runtime_tests/java11.cfg
index 4957d4794..7f8611a08 100644
--- a/kokoro/runtime_tests/java11.cfg
+++ b/kokoro/runtime_tests/java11.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "java11"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/nodejs12.4.0.cfg b/kokoro/runtime_tests/nodejs12.4.0.cfg
index 1df343f95..c67ad5567 100644
--- a/kokoro/runtime_tests/nodejs12.4.0.cfg
+++ b/kokoro/runtime_tests/nodejs12.4.0.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "nodejs12.4.0"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/php7.3.6.cfg b/kokoro/runtime_tests/php7.3.6.cfg
index 8e3667125..f266c5e26 100644
--- a/kokoro/runtime_tests/php7.3.6.cfg
+++ b/kokoro/runtime_tests/php7.3.6.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "php7.3.6"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/kokoro/runtime_tests/python3.7.3.cfg b/kokoro/runtime_tests/python3.7.3.cfg
index 0ca70d5bb..574add152 100644
--- a/kokoro/runtime_tests/python3.7.3.cfg
+++ b/kokoro/runtime_tests/python3.7.3.cfg
@@ -4,3 +4,13 @@ env_vars {
key: "RUNTIME_TEST_NAME"
value: "python3.7.3"
}
+
+action {
+ define_artifacts {
+ regex: "**/sponge_log.xml"
+ regex: "**/sponge_log.log"
+ regex: "**/outputs.zip"
+ regex: "**/runsc"
+ regex: "**/runsc.*"
+ }
+} \ No newline at end of file
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index a89f34d4b..322d1ccc4 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -30,6 +30,7 @@ go_library(
"futex.go",
"inotify.go",
"ioctl.go",
+ "ioctl_tun.go",
"ip.go",
"ipc.go",
"limits.go",
diff --git a/pkg/abi/linux/epoll.go b/pkg/abi/linux/epoll.go
index 6e4de69da..1121a1a92 100644
--- a/pkg/abi/linux/epoll.go
+++ b/pkg/abi/linux/epoll.go
@@ -14,6 +14,10 @@
package linux
+import (
+ "gvisor.dev/gvisor/pkg/binary"
+)
+
// Event masks.
const (
EPOLLIN = 0x1
@@ -53,3 +57,6 @@ const (
EPOLL_CTL_DEL = 0x2
EPOLL_CTL_MOD = 0x3
)
+
+// SizeOfEpollEvent is the size of EpollEvent struct.
+var SizeOfEpollEvent = int(binary.Size(EpollEvent{}))
diff --git a/pkg/abi/linux/epoll_amd64.go b/pkg/abi/linux/epoll_amd64.go
index 57041491c..34ff18009 100644
--- a/pkg/abi/linux/epoll_amd64.go
+++ b/pkg/abi/linux/epoll_amd64.go
@@ -15,6 +15,8 @@
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event::data a __u64. We represent it as
diff --git a/pkg/abi/linux/epoll_arm64.go b/pkg/abi/linux/epoll_arm64.go
index 62ef5821e..f86c35329 100644
--- a/pkg/abi/linux/epoll_arm64.go
+++ b/pkg/abi/linux/epoll_arm64.go
@@ -15,6 +15,8 @@
package linux
// EpollEvent is equivalent to struct epoll_event from epoll(2).
+//
+// +marshal
type EpollEvent struct {
Events uint32
// Linux makes struct epoll_event a __u64, necessitating 4 bytes of padding
diff --git a/pkg/abi/linux/file.go b/pkg/abi/linux/file.go
index c3ab15a4f..e229ac21c 100644
--- a/pkg/abi/linux/file.go
+++ b/pkg/abi/linux/file.go
@@ -241,6 +241,8 @@ const (
)
// Statx represents struct statx.
+//
+// +marshal
type Statx struct {
Mask uint32
Blksize uint32
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 2c652baa2..158d2db5b 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -38,6 +38,8 @@ const (
)
// Statfs is struct statfs, from uapi/asm-generic/statfs.h.
+//
+// +marshal
type Statfs struct {
// Type is one of the filesystem magic values, defined above.
Type uint64
diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go
index 0e18db9ef..2062e6a4b 100644
--- a/pkg/abi/linux/ioctl.go
+++ b/pkg/abi/linux/ioctl.go
@@ -72,3 +72,29 @@ const (
SIOCGMIIPHY = 0x8947
SIOCGMIIREG = 0x8948
)
+
+// ioctl(2) directions. Used to calculate requests number.
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NONE = 0
+ _IOC_WRITE = 1
+ _IOC_READ = 2
+)
+
+// Constants from asm-generic/ioctl.h.
+const (
+ _IOC_NRBITS = 8
+ _IOC_TYPEBITS = 8
+ _IOC_SIZEBITS = 14
+ _IOC_DIRBITS = 2
+
+ _IOC_NRSHIFT = 0
+ _IOC_TYPESHIFT = _IOC_NRSHIFT + _IOC_NRBITS
+ _IOC_SIZESHIFT = _IOC_TYPESHIFT + _IOC_TYPEBITS
+ _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS
+)
+
+// IOC outputs the result of _IOC macro in asm-generic/ioctl.h.
+func IOC(dir, typ, nr, size uint32) uint32 {
+ return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT
+}
diff --git a/pkg/fspath/builder_unsafe.go b/pkg/abi/linux/ioctl_tun.go
index 75606808d..c59c9c136 100644
--- a/pkg/fspath/builder_unsafe.go
+++ b/pkg/abi/linux/ioctl_tun.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,16 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package fspath
+package linux
-import (
- "unsafe"
+// ioctl(2) request numbers from linux/if_tun.h
+var (
+ TUNSETIFF = IOC(_IOC_WRITE, 'T', 202, 4)
+ TUNGETIFF = IOC(_IOC_READ, 'T', 210, 4)
)
-// String returns the accumulated string. No other methods should be called
-// after String.
-func (b *Builder) String() string {
- bs := b.buf[b.start:]
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
-}
+// Flags from net/if_tun.h
+const (
+ IFF_TUN = 0x0001
+ IFF_TAP = 0x0002
+ IFF_NO_PI = 0x1000
+ IFF_NOFILTER = 0x1000
+)
diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go
index c69b04ea9..1c330e763 100644
--- a/pkg/abi/linux/signal.go
+++ b/pkg/abi/linux/signal.go
@@ -115,6 +115,8 @@ const (
)
// SignalSet is a signal mask with a bit corresponding to each signal.
+//
+// +marshal
type SignalSet uint64
// SignalSetSize is the size in bytes of a SignalSet.
diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go
index e562b46d9..e6860ed49 100644
--- a/pkg/abi/linux/time.go
+++ b/pkg/abi/linux/time.go
@@ -157,6 +157,8 @@ func DurationToTimespec(dur time.Duration) Timespec {
const SizeOfTimeval = 16
// Timeval represents struct timeval in <time.h>.
+//
+// +marshal
type Timeval struct {
Sec int64
Usec int64
@@ -230,6 +232,8 @@ type Tms struct {
type TimerID int32
// StatxTimestamp represents struct statx_timestamp.
+//
+// +marshal
type StatxTimestamp struct {
Sec int64
Nsec uint32
@@ -258,6 +262,8 @@ func NsecToStatxTimestamp(nsec int64) (ts StatxTimestamp) {
}
// Utime represents struct utimbuf used by utimes(2).
+//
+// +marshal
type Utime struct {
Actime int64
Modtime int64
diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go
index a3b6406fa..99180b208 100644
--- a/pkg/abi/linux/xattr.go
+++ b/pkg/abi/linux/xattr.go
@@ -18,6 +18,7 @@ package linux
const (
XATTR_NAME_MAX = 255
XATTR_SIZE_MAX = 65536
+ XATTR_LIST_MAX = 65536
XATTR_CREATE = 1
XATTR_REPLACE = 2
diff --git a/pkg/atomicbitops/BUILD b/pkg/atomicbitops/BUILD
index ba8b06071..1a30f6967 100644
--- a/pkg/atomicbitops/BUILD
+++ b/pkg/atomicbitops/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"atomicbitops.go",
"atomicbitops_amd64.s",
+ "atomicbitops_arm64.s",
"atomicbitops_noasm.go",
],
visibility = ["//:sandbox"],
diff --git a/pkg/fspath/BUILD b/pkg/fspath/BUILD
index ee84471b2..67dd1e225 100644
--- a/pkg/fspath/BUILD
+++ b/pkg/fspath/BUILD
@@ -8,9 +8,11 @@ go_library(
name = "fspath",
srcs = [
"builder.go",
- "builder_unsafe.go",
"fspath.go",
],
+ deps = [
+ "//pkg/gohacks",
+ ],
)
go_test(
diff --git a/pkg/fspath/builder.go b/pkg/fspath/builder.go
index 7ddb36826..6318d3874 100644
--- a/pkg/fspath/builder.go
+++ b/pkg/fspath/builder.go
@@ -16,6 +16,8 @@ package fspath
import (
"fmt"
+
+ "gvisor.dev/gvisor/pkg/gohacks"
)
// Builder is similar to strings.Builder, but is used to produce pathnames
@@ -102,3 +104,9 @@ func (b *Builder) AppendString(str string) {
copy(b.buf[b.start:], b.buf[oldStart:])
copy(b.buf[len(b.buf)-len(str):], str)
}
+
+// String returns the accumulated string. No other methods should be called
+// after String.
+func (b *Builder) String() string {
+ return gohacks.StringFromImmutableBytes(b.buf[b.start:])
+}
diff --git a/pkg/fspath/fspath.go b/pkg/fspath/fspath.go
index 9fb3fee24..4c983d5fd 100644
--- a/pkg/fspath/fspath.go
+++ b/pkg/fspath/fspath.go
@@ -67,7 +67,8 @@ func Parse(pathname string) Path {
// Path contains the information contained in a pathname string.
//
-// Path is copyable by value.
+// Path is copyable by value. The zero value for Path is equivalent to
+// fspath.Parse(""), i.e. the empty path.
type Path struct {
// Begin is an iterator to the first path component in the relative part of
// the path.
diff --git a/pkg/gohacks/BUILD b/pkg/gohacks/BUILD
new file mode 100644
index 000000000..798a65eca
--- /dev/null
+++ b/pkg/gohacks/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "gohacks",
+ srcs = [
+ "gohacks_unsafe.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
new file mode 100644
index 000000000..aad675172
--- /dev/null
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -0,0 +1,57 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package gohacks contains utilities for subverting the Go compiler.
+package gohacks
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Noescape hides a pointer from escape analysis. Noescape is the identity
+// function but escape analysis doesn't think the output depends on the input.
+// Noescape is inlined and currently compiles down to zero instructions.
+// USE CAREFULLY!
+//
+// (Noescape is copy/pasted from Go's runtime/stubs.go:noescape().)
+//
+//go:nosplit
+func Noescape(p unsafe.Pointer) unsafe.Pointer {
+ x := uintptr(p)
+ return unsafe.Pointer(x ^ 0)
+}
+
+// ImmutableBytesFromString is equivalent to []byte(s), except that it uses the
+// same memory backing s instead of making a heap-allocated copy. This is only
+// valid if the returned slice is never mutated.
+func ImmutableBytesFromString(s string) []byte {
+ shdr := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ var bs []byte
+ bshdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
+ bshdr.Data = shdr.Data
+ bshdr.Len = shdr.Len
+ bshdr.Cap = shdr.Len
+ return bs
+}
+
+// StringFromImmutableBytes is equivalent to string(bs), except that it uses
+// the same memory backing bs instead of making a heap-allocated copy. This is
+// only valid if bs is never mutated after StringFromImmutableBytes returns.
+func StringFromImmutableBytes(bs []byte) string {
+ // This is cheaper than messing with reflect.StringHeader and
+ // reflect.SliceHeader, which as of this writing produces many dead stores
+ // of zeroes. Compare strings.Builder.String().
+ return *(*string)(unsafe.Pointer(&bs))
+}
diff --git a/pkg/log/glog.go b/pkg/log/glog.go
index cab5fae55..b4f7bb5a4 100644
--- a/pkg/log/glog.go
+++ b/pkg/log/glog.go
@@ -46,7 +46,7 @@ var pid = os.Getpid()
// line The line number
// msg The user-supplied message
//
-func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (g *GoogleEmitter) Emit(depth int, level Level, timestamp time.Time, format string, args ...interface{}) {
// Log level.
prefix := byte('?')
switch level {
@@ -64,9 +64,7 @@ func (g *GoogleEmitter) Emit(level Level, timestamp time.Time, format string, ar
microsecond := int(timestamp.Nanosecond() / 1000)
// 0 = this frame.
- // 1 = Debugf, etc.
- // 2 = Caller.
- _, file, line, ok := runtime.Caller(2)
+ _, file, line, ok := runtime.Caller(depth + 1)
if ok {
// Trim any directory path from the file.
slash := strings.LastIndexByte(file, byte('/'))
diff --git a/pkg/log/json.go b/pkg/log/json.go
index a278c8fc8..0943db1cc 100644
--- a/pkg/log/json.go
+++ b/pkg/log/json.go
@@ -62,7 +62,7 @@ type JSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e JSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e JSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := jsonLog{
Msg: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/json_k8s.go b/pkg/log/json_k8s.go
index cee6eb514..6c6fc8b6f 100644
--- a/pkg/log/json_k8s.go
+++ b/pkg/log/json_k8s.go
@@ -33,7 +33,7 @@ type K8sJSONEmitter struct {
}
// Emit implements Emitter.Emit.
-func (e *K8sJSONEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (e *K8sJSONEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
j := k8sJSONLog{
Log: fmt.Sprintf(format, v...),
Level: level,
diff --git a/pkg/log/log.go b/pkg/log/log.go
index 5056f17e6..a794da1aa 100644
--- a/pkg/log/log.go
+++ b/pkg/log/log.go
@@ -79,7 +79,7 @@ func (l Level) String() string {
type Emitter interface {
// Emit emits the given log statement. This allows for control over the
// timestamp used for logging.
- Emit(level Level, timestamp time.Time, format string, v ...interface{})
+ Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{})
}
// Writer writes the output to the given writer.
@@ -142,7 +142,7 @@ func (l *Writer) Write(data []byte) (int, error) {
}
// Emit emits the message.
-func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...interface{}) {
+func (l *Writer) Emit(_ int, _ Level, _ time.Time, format string, args ...interface{}) {
fmt.Fprintf(l, format, args...)
}
@@ -150,9 +150,9 @@ func (l *Writer) Emit(level Level, timestamp time.Time, format string, args ...i
type MultiEmitter []Emitter
// Emit emits to all emitters.
-func (m *MultiEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (m *MultiEmitter) Emit(depth int, level Level, timestamp time.Time, format string, v ...interface{}) {
for _, e := range *m {
- e.Emit(level, timestamp, format, v...)
+ e.Emit(1+depth, level, timestamp, format, v...)
}
}
@@ -167,7 +167,7 @@ type TestEmitter struct {
}
// Emit emits to the TestLogger.
-func (t *TestEmitter) Emit(level Level, timestamp time.Time, format string, v ...interface{}) {
+func (t *TestEmitter) Emit(_ int, level Level, timestamp time.Time, format string, v ...interface{}) {
t.Logf(format, v...)
}
@@ -198,22 +198,37 @@ type BasicLogger struct {
// Debugf implements logger.Debugf.
func (l *BasicLogger) Debugf(format string, v ...interface{}) {
- if l.IsLogging(Debug) {
- l.Emit(Debug, time.Now(), format, v...)
- }
+ l.DebugfAtDepth(1, format, v...)
}
// Infof implements logger.Infof.
func (l *BasicLogger) Infof(format string, v ...interface{}) {
- if l.IsLogging(Info) {
- l.Emit(Info, time.Now(), format, v...)
- }
+ l.InfofAtDepth(1, format, v...)
}
// Warningf implements logger.Warningf.
func (l *BasicLogger) Warningf(format string, v ...interface{}) {
+ l.WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs at a specific depth.
+func (l *BasicLogger) DebugfAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Debug) {
+ l.Emit(1+depth, Debug, time.Now(), format, v...)
+ }
+}
+
+// InfofAtDepth logs at a specific depth.
+func (l *BasicLogger) InfofAtDepth(depth int, format string, v ...interface{}) {
+ if l.IsLogging(Info) {
+ l.Emit(1+depth, Info, time.Now(), format, v...)
+ }
+}
+
+// WarningfAtDepth logs at a specific depth.
+func (l *BasicLogger) WarningfAtDepth(depth int, format string, v ...interface{}) {
if l.IsLogging(Warning) {
- l.Emit(Warning, time.Now(), format, v...)
+ l.Emit(1+depth, Warning, time.Now(), format, v...)
}
}
@@ -257,17 +272,32 @@ func SetLevel(newLevel Level) {
// Debugf logs to the global logger.
func Debugf(format string, v ...interface{}) {
- Log().Debugf(format, v...)
+ Log().DebugfAtDepth(1, format, v...)
}
// Infof logs to the global logger.
func Infof(format string, v ...interface{}) {
- Log().Infof(format, v...)
+ Log().InfofAtDepth(1, format, v...)
}
// Warningf logs to the global logger.
func Warningf(format string, v ...interface{}) {
- Log().Warningf(format, v...)
+ Log().WarningfAtDepth(1, format, v...)
+}
+
+// DebugfAtDepth logs to the global logger.
+func DebugfAtDepth(depth int, format string, v ...interface{}) {
+ Log().DebugfAtDepth(1+depth, format, v...)
+}
+
+// InfofAtDepth logs to the global logger.
+func InfofAtDepth(depth int, format string, v ...interface{}) {
+ Log().InfofAtDepth(1+depth, format, v...)
+}
+
+// WarningfAtDepth logs to the global logger.
+func WarningfAtDepth(depth int, format string, v ...interface{}) {
+ Log().WarningfAtDepth(1+depth, format, v...)
}
// defaultStackSize is the default buffer size to allocate for stack traces.
diff --git a/pkg/sentry/arch/arch_aarch64.go b/pkg/sentry/arch/arch_aarch64.go
index 3b6987665..d7794db63 100644
--- a/pkg/sentry/arch/arch_aarch64.go
+++ b/pkg/sentry/arch/arch_aarch64.go
@@ -34,6 +34,9 @@ const (
SyscallWidth = 4
)
+// ARMTrapFlag is the mask for the trap flag.
+const ARMTrapFlag = uint64(1) << 21
+
// aarch64FPState is aarch64 floating point state.
type aarch64FPState []byte
diff --git a/pkg/sentry/fs/dev/BUILD b/pkg/sentry/fs/dev/BUILD
index 4c4b7d5cc..9b6bb26d0 100644
--- a/pkg/sentry/fs/dev/BUILD
+++ b/pkg/sentry/fs/dev/BUILD
@@ -9,6 +9,7 @@ go_library(
"device.go",
"fs.go",
"full.go",
+ "net_tun.go",
"null.go",
"random.go",
"tty.go",
@@ -19,15 +20,19 @@ go_library(
"//pkg/context",
"//pkg/rand",
"//pkg/safemem",
+ "//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/fs/ramfs",
"//pkg/sentry/fs/tmpfs",
+ "//pkg/sentry/kernel",
"//pkg/sentry/memmap",
"//pkg/sentry/mm",
"//pkg/sentry/pgalloc",
+ "//pkg/sentry/socket/netstack",
"//pkg/syserror",
+ "//pkg/tcpip/link/tun",
"//pkg/usermem",
"//pkg/waiter",
],
diff --git a/pkg/sentry/fs/dev/dev.go b/pkg/sentry/fs/dev/dev.go
index 35bd23991..7e66c29b0 100644
--- a/pkg/sentry/fs/dev/dev.go
+++ b/pkg/sentry/fs/dev/dev.go
@@ -66,8 +66,8 @@ func newMemDevice(ctx context.Context, iops fs.InodeOperations, msrc *fs.MountSo
})
}
-func newDirectory(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
- iops := ramfs.NewDir(ctx, nil, fs.RootOwner, fs.FilePermsFromMode(0555))
+func newDirectory(ctx context.Context, contents map[string]*fs.Inode, msrc *fs.MountSource) *fs.Inode {
+ iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
return fs.NewInode(ctx, iops, msrc, fs.StableAttr{
DeviceID: devDevice.DeviceID(),
InodeID: devDevice.NextIno(),
@@ -111,7 +111,7 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
// A devpts is typically mounted at /dev/pts to provide
// pseudoterminal support. Place an empty directory there for
// the devpts to be mounted over.
- "pts": newDirectory(ctx, msrc),
+ "pts": newDirectory(ctx, nil, msrc),
// Similarly, applications expect a ptmx device at /dev/ptmx
// connected to the terminals provided by /dev/pts/. Rather
// than creating a device directly (which requires a hairy
@@ -124,6 +124,10 @@ func New(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
"ptmx": newSymlink(ctx, "pts/ptmx", msrc),
"tty": newCharacterDevice(ctx, newTTYDevice(ctx, fs.RootOwner, 0666), msrc, ttyDevMajor, ttyDevMinor),
+
+ "net": newDirectory(ctx, map[string]*fs.Inode{
+ "tun": newCharacterDevice(ctx, newNetTunDevice(ctx, fs.RootOwner, 0666), msrc, netTunDevMajor, netTunDevMinor),
+ }, msrc),
}
iops := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555))
diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go
new file mode 100644
index 000000000..755644488
--- /dev/null
+++ b/pkg/sentry/fs/dev/net_tun.go
@@ -0,0 +1,170 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dev
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/link/tun"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ netTunDevMajor = 10
+ netTunDevMinor = 200
+)
+
+// +stateify savable
+type netTunInodeOperations struct {
+ fsutil.InodeGenericChecker `state:"nosave"`
+ fsutil.InodeNoExtendedAttributes `state:"nosave"`
+ fsutil.InodeNoopAllocate `state:"nosave"`
+ fsutil.InodeNoopRelease `state:"nosave"`
+ fsutil.InodeNoopTruncate `state:"nosave"`
+ fsutil.InodeNoopWriteOut `state:"nosave"`
+ fsutil.InodeNotDirectory `state:"nosave"`
+ fsutil.InodeNotMappable `state:"nosave"`
+ fsutil.InodeNotSocket `state:"nosave"`
+ fsutil.InodeNotSymlink `state:"nosave"`
+ fsutil.InodeVirtual `state:"nosave"`
+
+ fsutil.InodeSimpleAttributes
+}
+
+var _ fs.InodeOperations = (*netTunInodeOperations)(nil)
+
+func newNetTunDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) *netTunInodeOperations {
+ return &netTunInodeOperations{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, owner, fs.FilePermsFromMode(mode), linux.TMPFS_MAGIC),
+ }
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (iops *netTunInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ return fs.NewFile(ctx, d, flags, &netTunFileOperations{}), nil
+}
+
+// +stateify savable
+type netTunFileOperations struct {
+ fsutil.FileNoSeek `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ device tun.Device
+}
+
+var _ fs.FileOperations = (*netTunFileOperations)(nil)
+
+// Release implements fs.FileOperations.Release.
+func (fops *netTunFileOperations) Release() {
+ fops.device.Release()
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ request := args[1].Uint()
+ data := args[2].Pointer()
+
+ switch request {
+ case linux.TUNSETIFF:
+ t := kernel.TaskFromContext(ctx)
+ if t == nil {
+ panic("Ioctl should be called from a task context")
+ }
+ if !t.HasCapability(linux.CAP_NET_ADMIN) {
+ return 0, syserror.EPERM
+ }
+ stack, ok := t.NetworkContext().(*netstack.Stack)
+ if !ok {
+ return 0, syserror.EINVAL
+ }
+
+ var req linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ flags := usermem.ByteOrder.Uint16(req.Data[:])
+ return 0, fops.device.SetIff(stack.Stack, req.Name(), flags)
+
+ case linux.TUNGETIFF:
+ var req linux.IFReq
+
+ copy(req.IFName[:], fops.device.Name())
+
+ // Linux adds IFF_NOFILTER (the same value as IFF_NO_PI unfortunately) when
+ // there is no sk_filter. See __tun_chr_ioctl() in net/drivers/tun.c.
+ flags := fops.device.Flags() | linux.IFF_NOFILTER
+ usermem.ByteOrder.PutUint16(req.Data[:], flags)
+
+ _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+// Write implements fs.FileOperations.Write.
+func (fops *netTunFileOperations) Write(ctx context.Context, file *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ data := make([]byte, src.NumBytes())
+ if _, err := src.CopyIn(ctx, data); err != nil {
+ return 0, err
+ }
+ return fops.device.Write(data)
+}
+
+// Read implements fs.FileOperations.Read.
+func (fops *netTunFileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ data, err := fops.device.Read()
+ if err != nil {
+ return 0, err
+ }
+ n, err := dst.CopyOut(ctx, data)
+ if n > 0 && n < len(data) {
+ // Not an error for partial copying. Packet truncated.
+ err = nil
+ }
+ return int64(n), err
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (fops *netTunFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fops.device.Readiness(mask)
+}
+
+// EventRegister implements watier.Waitable.EventRegister.
+func (fops *netTunFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fops.device.EventRegister(e, mask)
+}
+
+// EventUnregister implements watier.Waitable.EventUnregister.
+func (fops *netTunFileOperations) EventUnregister(e *waiter.Entry) {
+ fops.device.EventUnregister(e)
+}
diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go
index acab0411a..e0b32e1c1 100644
--- a/pkg/sentry/fs/dirent.go
+++ b/pkg/sentry/fs/dirent.go
@@ -1438,8 +1438,8 @@ func lockForRename(oldParent *Dirent, oldName string, newParent *Dirent, newName
}, nil
}
-func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error {
- uattr, err := dir.Inode.UnstableAttr(ctx)
+func (d *Dirent) checkSticky(ctx context.Context, victim *Dirent) error {
+ uattr, err := d.Inode.UnstableAttr(ctx)
if err != nil {
return syserror.EPERM
}
@@ -1465,30 +1465,33 @@ func checkSticky(ctx context.Context, dir *Dirent, victim *Dirent) error {
return syserror.EPERM
}
-// MayDelete determines whether `name`, a child of `dir`, can be deleted or
+// MayDelete determines whether `name`, a child of `d`, can be deleted or
// renamed by `ctx`.
//
// Compare Linux kernel fs/namei.c:may_delete.
-func MayDelete(ctx context.Context, root, dir *Dirent, name string) error {
- if err := dir.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
+func (d *Dirent) MayDelete(ctx context.Context, root *Dirent, name string) error {
+ if err := d.Inode.CheckPermission(ctx, PermMask{Write: true, Execute: true}); err != nil {
return err
}
- victim, err := dir.Walk(ctx, root, name)
+ unlock := d.lockDirectory()
+ defer unlock()
+
+ victim, err := d.walk(ctx, root, name, true /* may unlock */)
if err != nil {
return err
}
defer victim.DecRef()
- return mayDelete(ctx, dir, victim)
+ return d.mayDelete(ctx, victim)
}
// mayDelete determines whether `victim`, a child of `dir`, can be deleted or
// renamed by `ctx`.
//
// Preconditions: `dir` is writable and executable by `ctx`.
-func mayDelete(ctx context.Context, dir, victim *Dirent) error {
- if err := checkSticky(ctx, dir, victim); err != nil {
+func (d *Dirent) mayDelete(ctx context.Context, victim *Dirent) error {
+ if err := d.checkSticky(ctx, victim); err != nil {
return err
}
@@ -1542,7 +1545,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
defer renamed.DecRef()
// Check that the renamed dirent is deletable.
- if err := mayDelete(ctx, oldParent, renamed); err != nil {
+ if err := oldParent.mayDelete(ctx, renamed); err != nil {
return err
}
@@ -1580,7 +1583,7 @@ func Rename(ctx context.Context, root *Dirent, oldParent *Dirent, oldName string
// across the Rename, so must call DecRef manually (no defer).
// Check that we can delete replaced.
- if err := mayDelete(ctx, newParent, replaced); err != nil {
+ if err := newParent.mayDelete(ctx, replaced); err != nil {
replaced.DecRef()
return err
}
diff --git a/pkg/sentry/fs/mount_test.go b/pkg/sentry/fs/mount_test.go
index e672a438c..a3d10770b 100644
--- a/pkg/sentry/fs/mount_test.go
+++ b/pkg/sentry/fs/mount_test.go
@@ -36,11 +36,12 @@ func mountPathsAre(root *Dirent, got []*Mount, want ...string) error {
gotPaths := make(map[string]struct{}, len(got))
gotStr := make([]string, len(got))
for i, g := range got {
- groot := g.Root()
- name, _ := groot.FullName(root)
- groot.DecRef()
- gotStr[i] = name
- gotPaths[name] = struct{}{}
+ if groot := g.Root(); groot != nil {
+ name, _ := groot.FullName(root)
+ groot.DecRef()
+ gotStr[i] = name
+ gotPaths[name] = struct{}{}
+ }
}
if len(got) != len(want) {
return fmt.Errorf("mount paths are different, got: %q, want: %q", gotStr, want)
diff --git a/pkg/sentry/fs/mounts.go b/pkg/sentry/fs/mounts.go
index 574a2cc91..c7981f66e 100644
--- a/pkg/sentry/fs/mounts.go
+++ b/pkg/sentry/fs/mounts.go
@@ -100,10 +100,14 @@ func newUndoMount(d *Dirent) *Mount {
}
}
-// Root returns the root dirent of this mount. Callers must call DecRef on the
-// returned dirent.
+// Root returns the root dirent of this mount.
+//
+// This may return nil if the mount has already been free. Callers must handle this
+// case appropriately. If non-nil, callers must call DecRef on the returned *Dirent.
func (m *Mount) Root() *Dirent {
- m.root.IncRef()
+ if !m.root.TryIncRef() {
+ return nil
+ }
return m.root
}
diff --git a/pkg/sentry/fs/proc/mounts.go b/pkg/sentry/fs/proc/mounts.go
index c10888100..94deb553b 100644
--- a/pkg/sentry/fs/proc/mounts.go
+++ b/pkg/sentry/fs/proc/mounts.go
@@ -60,13 +60,15 @@ func forEachMount(t *kernel.Task, fn func(string, *fs.Mount)) {
})
for _, m := range ms {
mroot := m.Root()
+ if mroot == nil {
+ continue // No longer valid.
+ }
mountPath, desc := mroot.FullName(rootDir)
mroot.DecRef()
if !desc {
// MountSources that are not descendants of the chroot jail are ignored.
continue
}
-
fn(mountPath, m)
}
}
@@ -91,6 +93,12 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
var buf bytes.Buffer
forEachMount(mif.t, func(mountPath string, m *fs.Mount) {
+ mroot := m.Root()
+ if mroot == nil {
+ return // No longer valid.
+ }
+ defer mroot.DecRef()
+
// Format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
@@ -107,9 +115,6 @@ func (mif *mountInfoFile) ReadSeqFileData(ctx context.Context, handle seqfile.Se
// (3) Major:Minor device ID. We don't have a superblock, so we
// just use the root inode device number.
- mroot := m.Root()
- defer mroot.DecRef()
-
sa := mroot.Inode.StableAttr
fmt.Fprintf(&buf, "%d:%d ", sa.DeviceFileMajor, sa.DeviceFileMinor)
@@ -207,6 +212,9 @@ func (mf *mountsFile) ReadSeqFileData(ctx context.Context, handle seqfile.SeqHan
//
// The "needs dump"and fsck flags are always 0, which is allowed.
root := m.Root()
+ if root == nil {
+ return // No longer valid.
+ }
defer root.DecRef()
flags := root.Inode.MountSource.Flags
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 6f2775344..95d5817ff 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -43,7 +43,10 @@ import (
// newNet creates a new proc net entry.
func (p *proc) newNetDir(ctx context.Context, k *kernel.Kernel, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"dev": seqfile.NewSeqFileInode(ctx, &netDev{s: s}, msrc),
"snmp": seqfile.NewSeqFileInode(ctx, &netSnmp{s: s}, msrc),
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 0772d4ae4..d4c4b533d 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -357,7 +357,9 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
func (p *proc) newSysNetDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode {
var contents map[string]*fs.Inode
- if s := p.k.NetworkStack(); s != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if s := p.k.RootNetworkNamespace().Stack(); s != nil {
contents = map[string]*fs.Inode{
"ipv4": p.newSysNetIPv4Dir(ctx, msrc, s),
"core": p.newSysNetCore(ctx, msrc, s),
diff --git a/pkg/sentry/fsbridge/vfs.go b/pkg/sentry/fsbridge/vfs.go
index e657c39bc..6aa17bfc1 100644
--- a/pkg/sentry/fsbridge/vfs.go
+++ b/pkg/sentry/fsbridge/vfs.go
@@ -117,15 +117,19 @@ func NewVFSLookup(mntns *vfs.MountNamespace, root, workingDir vfs.VirtualDentry)
// default anyways.
//
// TODO(gvisor.dev/issue/1623): Check mount has read and exec permission.
-func (l *vfsLookup) OpenPath(ctx context.Context, path string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
+func (l *vfsLookup) OpenPath(ctx context.Context, pathname string, opts vfs.OpenOptions, _ *uint, resolveFinal bool) (File, error) {
vfsObj := l.mntns.Root().Mount().Filesystem().VirtualFilesystem()
creds := auth.CredentialsFromContext(ctx)
+ path := fspath.Parse(pathname)
pop := &vfs.PathOperation{
Root: l.root,
- Start: l.root,
- Path: fspath.Parse(path),
+ Start: l.workingDir,
+ Path: path,
FollowFinalSymlink: resolveFinal,
}
+ if path.Absolute {
+ pop.Start = l.root
+ }
fd, err := vfsObj.OpenAt(ctx, creds, pop, &opts)
if err != nil {
return nil, err
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index ce08a7d53..10c08fa90 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -73,9 +73,9 @@ func newTasksInode(inoGen InoGenerator, k *kernel.Kernel, pidns *kernel.PIDNames
"meminfo": newDentry(root, inoGen.NextIno(), 0444, &meminfoData{}),
"mounts": kernfs.NewStaticSymlink(root, inoGen.NextIno(), "self/mounts"),
"net": newNetDir(root, inoGen, k),
- "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{}),
+ "stat": newDentry(root, inoGen.NextIno(), 0444, &statData{k: k}),
"uptime": newDentry(root, inoGen.NextIno(), 0444, &uptimeData{}),
- "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{}),
+ "version": newDentry(root, inoGen.NextIno(), 0444, &versionData{k: k}),
}
inode := &tasksInode{
diff --git a/pkg/sentry/fsimpl/proc/tasks_net.go b/pkg/sentry/fsimpl/proc/tasks_net.go
index 608fec017..d4e1812d8 100644
--- a/pkg/sentry/fsimpl/proc/tasks_net.go
+++ b/pkg/sentry/fsimpl/proc/tasks_net.go
@@ -39,7 +39,10 @@ import (
func newNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process. We should make this per-process,
+ // a.k.a. /proc/PID/net, and make /proc/net a symlink to /proc/self/net.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
const (
arp = "IP address HW type Flags HW address Mask Device\n"
netlink = "sk Eth Pid Groups Rmem Wmem Dump Locks Drops Inode\n"
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index c7ce74883..3d5dc463c 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -50,7 +50,9 @@ func newSysDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *k
func newSysNetDir(root *auth.Credentials, inoGen InoGenerator, k *kernel.Kernel) *kernfs.Dentry {
var contents map[string]*kernfs.Dentry
- if stack := k.NetworkStack(); stack != nil {
+ // TODO(gvisor.dev/issue/1833): Support for using the network stack in the
+ // network namespace of the calling process.
+ if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]*kernfs.Dentry{
"ipv4": kernfs.NewStaticDir(root, inoGen.NextIno(), 0555, map[string]*kernfs.Dentry{
"tcp_sack": newDentry(root, inoGen.NextIno(), 0644, &tcpSackData{stack: stack}),
diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go
index d0be32e72..488478e29 100644
--- a/pkg/sentry/fsimpl/testutil/kernel.go
+++ b/pkg/sentry/fsimpl/testutil/kernel.go
@@ -128,6 +128,7 @@ func CreateTask(ctx context.Context, name string, tc *kernel.ThreadGroup, mntns
ThreadGroup: tc,
TaskContext: &kernel.TaskContext{Name: name},
Credentials: auth.CredentialsFromContext(ctx),
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.ApplicationCores()),
UTSNamespace: kernel.UTSNamespaceFromContext(ctx),
IPCNamespace: kernel.IPCNamespaceFromContext(ctx),
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 7f7b791c4..e1b551422 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -16,7 +16,6 @@ package tmpfs
import (
"fmt"
- "sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -347,10 +346,9 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open
return nil, err
}
if opts.Flags&linux.O_TRUNC != 0 {
- impl.mu.Lock()
- impl.data.Truncate(0, impl.memFile)
- atomic.StoreUint64(&impl.size, 0)
- impl.mu.Unlock()
+ if _, err := impl.truncate(0); err != nil {
+ return nil, err
+ }
}
return &fd.vfsfd, nil
case *directory:
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index dab346a41..711442424 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -15,6 +15,7 @@
package tmpfs
import (
+ "fmt"
"io"
"math"
"sync/atomic"
@@ -22,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -34,25 +36,53 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// regularFile is a regular (=S_IFREG) tmpfs file.
type regularFile struct {
inode inode
// memFile is a platform.File used to allocate pages to this regularFile.
memFile *pgalloc.MemoryFile
- // mu protects the fields below.
- mu sync.RWMutex
+ // mapsMu protects mappings.
+ mapsMu sync.Mutex `state:"nosave"`
+
+ // mappings tracks mappings of the file into memmap.MappingSpaces.
+ //
+ // Protected by mapsMu.
+ mappings memmap.MappingSet
+
+ // writableMappingPages tracks how many pages of virtual memory are mapped
+ // as potentially writable from this file. If a page has multiple mappings,
+ // each mapping is counted separately.
+ //
+ // This counter is susceptible to overflow as we can potentially count
+ // mappings from many VMAs. We count pages rather than bytes to slightly
+ // mitigate this.
+ //
+ // Protected by mapsMu.
+ writableMappingPages uint64
+
+ // dataMu protects the fields below.
+ dataMu sync.RWMutex
// data maps offsets into the file to offsets into memFile that store
// the file's data.
+ //
+ // Protected by dataMu.
data fsutil.FileRangeSet
- // size is the size of data, but accessed using atomic memory
- // operations to avoid locking in inode.stat().
- size uint64
-
// seals represents file seals on this inode.
+ //
+ // Protected by dataMu.
seals uint32
+
+ // size is the size of data.
+ //
+ // Protected by both dataMu and inode.mu; reading it requires holding
+ // either mutex, while writing requires holding both AND using atomics.
+ // Readers that do not require consistency (like Stat) may read the
+ // value atomically without holding either lock.
+ size uint64
}
func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMode) *inode {
@@ -66,39 +96,170 @@ func (fs *filesystem) newRegularFile(creds *auth.Credentials, mode linux.FileMod
// truncate grows or shrinks the file to the given size. It returns true if the
// file size was updated.
-func (rf *regularFile) truncate(size uint64) (bool, error) {
- rf.mu.Lock()
- defer rf.mu.Unlock()
+func (rf *regularFile) truncate(newSize uint64) (bool, error) {
+ rf.inode.mu.Lock()
+ defer rf.inode.mu.Unlock()
+ return rf.truncateLocked(newSize)
+}
- if size == rf.size {
+// Preconditions: rf.inode.mu must be held.
+func (rf *regularFile) truncateLocked(newSize uint64) (bool, error) {
+ oldSize := rf.size
+ if newSize == oldSize {
// Nothing to do.
return false, nil
}
- if size > rf.size {
- // Growing the file.
+ // Need to hold inode.mu and dataMu while modifying size.
+ rf.dataMu.Lock()
+ if newSize > oldSize {
+ // Can we grow the file?
if rf.seals&linux.F_SEAL_GROW != 0 {
- // Seal does not allow growth.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- rf.size = size
+ // We only need to update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
return true, nil
}
- // Shrinking the file
+ // We are shrinking the file. First check if this is allowed.
if rf.seals&linux.F_SEAL_SHRINK != 0 {
- // Seal does not allow shrink.
+ rf.dataMu.Unlock()
return false, syserror.EPERM
}
- // TODO(gvisor.dev/issues/1197): Invalidate mappings once we have
- // mappings.
+ // Update the file size.
+ atomic.StoreUint64(&rf.size, newSize)
+ rf.dataMu.Unlock()
+
+ // Invalidate past translations of truncated pages.
+ oldpgend := fs.OffsetPageEnd(int64(oldSize))
+ newpgend := fs.OffsetPageEnd(int64(newSize))
+ if newpgend < oldpgend {
+ rf.mapsMu.Lock()
+ rf.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
+ // Compare Linux's mm/shmem.c:shmem_setattr() =>
+ // mm/memory.c:unmap_mapping_range(evencows=1).
+ InvalidatePrivate: true,
+ })
+ rf.mapsMu.Unlock()
+ }
- rf.data.Truncate(size, rf.memFile)
- rf.size = size
+ // We are now guaranteed that there are no translations of truncated pages,
+ // and can remove them.
+ rf.dataMu.Lock()
+ rf.data.Truncate(newSize, rf.memFile)
+ rf.dataMu.Unlock()
return true, nil
}
+// AddMapping implements memmap.Mappable.AddMapping.
+func (rf *regularFile) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+ rf.dataMu.RLock()
+ defer rf.dataMu.RUnlock()
+
+ // Reject writable mapping if F_SEAL_WRITE is set.
+ if rf.seals&linux.F_SEAL_WRITE != 0 && writable {
+ return syserror.EPERM
+ }
+
+ rf.mappings.AddMapping(ms, ar, offset, writable)
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages += uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages < pagesBefore {
+ panic(fmt.Sprintf("Overflow while mapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+
+ return nil
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (rf *regularFile) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) {
+ rf.mapsMu.Lock()
+ defer rf.mapsMu.Unlock()
+
+ rf.mappings.RemoveMapping(ms, ar, offset, writable)
+
+ if writable {
+ pagesBefore := rf.writableMappingPages
+
+ // ar is guaranteed to be page aligned per memmap.Mappable.
+ rf.writableMappingPages -= uint64(ar.Length() / usermem.PageSize)
+
+ if rf.writableMappingPages > pagesBefore {
+ panic(fmt.Sprintf("Underflow while unmapping potentially writable pages pointing to a tmpfs file. Before %v, after %v", pagesBefore, rf.writableMappingPages))
+ }
+ }
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (rf *regularFile) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error {
+ return rf.AddMapping(ctx, ms, dstAR, offset, writable)
+}
+
+// Translate implements memmap.Mappable.Translate.
+func (rf *regularFile) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) {
+ rf.dataMu.Lock()
+ defer rf.dataMu.Unlock()
+
+ // Constrain translations to f.attr.Size (rounded up) to prevent
+ // translation to pages that may be concurrently truncated.
+ pgend := fs.OffsetPageEnd(int64(rf.size))
+ var beyondEOF bool
+ if required.End > pgend {
+ if required.Start >= pgend {
+ return nil, &memmap.BusError{io.EOF}
+ }
+ beyondEOF = true
+ required.End = pgend
+ }
+ if optional.End > pgend {
+ optional.End = pgend
+ }
+
+ cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) {
+ // Newly-allocated pages are zeroed, so we don't need to do anything.
+ return dsts.NumBytes(), nil
+ })
+
+ var ts []memmap.Translation
+ var translatedEnd uint64
+ for seg := rf.data.FindSegment(required.Start); seg.Ok() && seg.Start() < required.End; seg, _ = seg.NextNonEmpty() {
+ segMR := seg.Range().Intersect(optional)
+ ts = append(ts, memmap.Translation{
+ Source: segMR,
+ File: rf.memFile,
+ Offset: seg.FileRangeOf(segMR).Start,
+ Perms: usermem.AnyAccess,
+ })
+ translatedEnd = segMR.End
+ }
+
+ // Don't return the error returned by f.data.Fill if it occurred outside of
+ // required.
+ if translatedEnd < required.End && cerr != nil {
+ return ts, &memmap.BusError{cerr}
+ }
+ if beyondEOF {
+ return ts, &memmap.BusError{io.EOF}
+ }
+ return ts, nil
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (*regularFile) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
type regularFileFD struct {
fileDescription
@@ -152,8 +313,10 @@ func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, off
// Overflow.
return 0, syserror.EFBIG
}
+ f.inode.mu.Lock()
rw := getRegularFileReadWriter(f, offset)
n, err := src.CopyInTo(ctx, rw)
+ f.inode.mu.Unlock()
putRegularFileReadWriter(rw)
return n, err
}
@@ -215,6 +378,12 @@ func (fd *regularFileFD) UnlockPOSIX(ctx context.Context, uid lock.UniqueID, rng
return nil
}
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ file := fd.inode().impl.(*regularFile)
+ return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
+}
+
// regularFileReadWriter implements safemem.Reader and Safemem.Writer.
type regularFileReadWriter struct {
file *regularFile
@@ -244,14 +413,15 @@ func putRegularFileReadWriter(rw *regularFileReadWriter) {
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
- rw.file.mu.RLock()
+ rw.file.dataMu.RLock()
+ defer rw.file.dataMu.RUnlock()
+ size := rw.file.size
// Compute the range to read (limited by file size and overflow-checked).
- if rw.off >= rw.file.size {
- rw.file.mu.RUnlock()
+ if rw.off >= size {
return 0, io.EOF
}
- end := rw.file.size
+ end := size
if rend := rw.off + dsts.NumBytes(); rend > rw.off && rend < end {
end = rend
}
@@ -265,7 +435,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
// Get internal mappings.
ims, err := rw.file.memFile.MapInternal(seg.FileRangeOf(seg.Range().Intersect(mr)), usermem.Read)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -275,7 +444,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -291,7 +459,6 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
rw.off += uint64(n)
dsts = dsts.DropFirst64(n)
if err != nil {
- rw.file.mu.RUnlock()
return done, err
}
@@ -299,13 +466,16 @@ func (rw *regularFileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, er
seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{}
}
}
- rw.file.mu.RUnlock()
return done, nil
}
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// Preconditions: inode.mu must be held.
func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
- rw.file.mu.Lock()
+ // Hold dataMu so we can modify size.
+ rw.file.dataMu.Lock()
+ defer rw.file.dataMu.Unlock()
// Compute the range to write (overflow-checked).
end := rw.off + srcs.NumBytes()
@@ -316,7 +486,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
// Check if seals prevent either file growth or all writes.
switch {
case rw.file.seals&linux.F_SEAL_WRITE != 0: // Write sealed
- rw.file.mu.Unlock()
return 0, syserror.EPERM
case end > rw.file.size && rw.file.seals&linux.F_SEAL_GROW != 0: // Grow sealed
// When growth is sealed, Linux effectively allows writes which would
@@ -338,7 +507,6 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64,
}
if end <= rw.off {
// Truncation would result in no data being written.
- rw.file.mu.Unlock()
return 0, syserror.EPERM
}
}
@@ -395,9 +563,8 @@ exitLoop:
// If the write ends beyond the file's previous size, it causes the
// file to grow.
if rw.off > rw.file.size {
- atomic.StoreUint64(&rw.file.size, rw.off)
+ rw.file.size = rw.off
}
- rw.file.mu.Unlock()
return done, retErr
}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index c5bb17562..521206305 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -18,9 +18,10 @@
// Lock order:
//
// filesystem.mu
-// regularFileFD.offMu
-// regularFile.mu
// inode.mu
+// regularFileFD.offMu
+// regularFile.mapsMu
+// regularFile.dataMu
package tmpfs
import (
@@ -226,12 +227,15 @@ func (i *inode) tryIncRef() bool {
func (i *inode) decRef() {
if refs := atomic.AddInt64(&i.refs, -1); refs == 0 {
- // This is unnecessary; it's mostly to simulate what tmpfs would do.
if regFile, ok := i.impl.(*regularFile); ok {
- regFile.mu.Lock()
+ // Hold inode.mu and regFile.dataMu while mutating
+ // size.
+ i.mu.Lock()
+ regFile.dataMu.Lock()
regFile.data.DropAll(regFile.memFile)
atomic.StoreUint64(&regFile.size, 0)
- regFile.mu.Unlock()
+ regFile.dataMu.Unlock()
+ i.mu.Unlock()
}
} else if refs < 0 {
panic("tmpfs.inode.decRef() called without holding a reference")
@@ -320,7 +324,7 @@ func (i *inode) setStat(stat linux.Statx) error {
if mask&linux.STATX_SIZE != 0 {
switch impl := i.impl.(type) {
case *regularFile:
- updated, err := impl.truncate(stat.Size)
+ updated, err := impl.truncateLocked(stat.Size)
if err != nil {
return err
}
diff --git a/pkg/sentry/inet/BUILD b/pkg/sentry/inet/BUILD
index 334432abf..07bf39fed 100644
--- a/pkg/sentry/inet/BUILD
+++ b/pkg/sentry/inet/BUILD
@@ -10,6 +10,7 @@ go_library(
srcs = [
"context.go",
"inet.go",
+ "namespace.go",
"test_stack.go",
],
deps = [
diff --git a/pkg/sentry/inet/namespace.go b/pkg/sentry/inet/namespace.go
new file mode 100644
index 000000000..c16667e7f
--- /dev/null
+++ b/pkg/sentry/inet/namespace.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package inet
+
+// Namespace represents a network namespace. See network_namespaces(7).
+//
+// +stateify savable
+type Namespace struct {
+ // stack is the network stack implementation of this network namespace.
+ stack Stack `state:"nosave"`
+
+ // creator allows kernel to create new network stack for network namespaces.
+ // If nil, no networking will function if network is namespaced.
+ creator NetworkStackCreator
+
+ // isRoot indicates whether this is the root network namespace.
+ isRoot bool
+}
+
+// NewRootNamespace creates the root network namespace, with creator
+// allowing new network namespaces to be created. If creator is nil, no
+// networking will function if the network is namespaced.
+func NewRootNamespace(stack Stack, creator NetworkStackCreator) *Namespace {
+ return &Namespace{
+ stack: stack,
+ creator: creator,
+ isRoot: true,
+ }
+}
+
+// NewNamespace creates a new network namespace from the root.
+func NewNamespace(root *Namespace) *Namespace {
+ n := &Namespace{
+ creator: root.creator,
+ }
+ n.init()
+ return n
+}
+
+// Stack returns the network stack of n. Stack may return nil if no network
+// stack is configured.
+func (n *Namespace) Stack() Stack {
+ return n.stack
+}
+
+// IsRoot returns whether n is the root network namespace.
+func (n *Namespace) IsRoot() bool {
+ return n.isRoot
+}
+
+// RestoreRootStack restores the root network namespace with stack. This should
+// only be called when restoring kernel.
+func (n *Namespace) RestoreRootStack(stack Stack) {
+ if !n.isRoot {
+ panic("RestoreRootStack can only be called on root network namespace")
+ }
+ if n.stack != nil {
+ panic("RestoreRootStack called after a stack has already been set")
+ }
+ n.stack = stack
+}
+
+func (n *Namespace) init() {
+ // Root network namespace will have stack assigned later.
+ if n.isRoot {
+ return
+ }
+ if n.creator != nil {
+ var err error
+ n.stack, err = n.creator.CreateStack()
+ if err != nil {
+ panic(err)
+ }
+ }
+}
+
+// afterLoad is invoked by stateify.
+func (n *Namespace) afterLoad() {
+ n.init()
+}
+
+// NetworkStackCreator allows new instances of a network stack to be created. It
+// is used by the kernel to create new network namespaces when requested.
+type NetworkStackCreator interface {
+ // CreateStack creates a new network stack for a network namespace.
+ CreateStack() (Stack, error)
+}
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 23b88f7a6..58001d56c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -296,6 +296,50 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return fds, nil
}
+// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// the given file description. If it succeeds, it takes a reference on file.
+func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ if minfd < 0 {
+ // Don't accept negative FDs.
+ return -1, syscall.EINVAL
+ }
+
+ // Default limit.
+ end := int32(math.MaxInt32)
+
+ // Ensure we don't get past the provided limit.
+ if limitSet := limits.FromContext(ctx); limitSet != nil {
+ lim := limitSet.Get(limits.NumberOfFiles)
+ if lim.Cur != limits.Infinity {
+ end = int32(lim.Cur)
+ }
+ if minfd >= end {
+ return -1, syscall.EMFILE
+ }
+ }
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+ // From f.next to find available fd.
+ fd := minfd
+ if fd < f.next {
+ fd = f.next
+ }
+ for fd < end {
+ if d, _, _ := f.get(fd); d == nil {
+ f.setVFS2(fd, file, flags)
+ if fd == f.next {
+ // Update next search start position.
+ f.next = fd + 1
+ }
+ return fd, nil
+ }
+ fd++
+ }
+ return -1, syscall.EMFILE
+}
+
// NewFDAt sets the file reference for the given FD. If there is an active
// reference for that FD, the ref count for that existing reference is
// decremented.
@@ -316,9 +360,6 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
return syscall.EBADF
}
- f.mu.Lock()
- defer f.mu.Unlock()
-
// Check the limit for the provided file.
if limitSet := limits.FromContext(ctx); limitSet != nil {
if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur {
@@ -327,6 +368,8 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
// Install the entry.
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.setAll(fd, file, fileVFS2, flags)
return nil
}
diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go
index 7218aa24e..47f78df9a 100644
--- a/pkg/sentry/kernel/fs_context.go
+++ b/pkg/sentry/kernel/fs_context.go
@@ -244,6 +244,28 @@ func (f *FSContext) SetRootDirectory(d *fs.Dirent) {
old.DecRef()
}
+// SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd.
+//
+// This is not a valid call after free.
+func (f *FSContext) SetRootDirectoryVFS2(vd vfs.VirtualDentry) {
+ if !vd.Ok() {
+ panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry")
+ }
+
+ f.mu.Lock()
+
+ if !f.rootVFS2.Ok() {
+ f.mu.Unlock()
+ panic(fmt.Sprintf("FSContext.SetRootDirectoryVFS2(%v)) called after destroy", vd))
+ }
+
+ old := f.rootVFS2
+ vd.IncRef()
+ f.rootVFS2 = vd
+ f.mu.Unlock()
+ old.DecRef()
+}
+
// Umask returns the current umask.
func (f *FSContext) Umask() uint {
f.mu.Lock()
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 7da0368f1..8b76750e9 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -111,7 +111,7 @@ type Kernel struct {
timekeeper *Timekeeper
tasks *TaskSet
rootUserNamespace *auth.UserNamespace
- networkStack inet.Stack `state:"nosave"`
+ rootNetworkNamespace *inet.Namespace
applicationCores uint
useHostCores bool
extraAuxv []arch.AuxEntry
@@ -247,6 +247,10 @@ type Kernel struct {
// VFS keeps the filesystem state used across the kernel.
vfs vfs.VirtualFilesystem
+
+ // If set to true, report address space activation waits as if the task is in
+ // external wait so that the watchdog doesn't report the task stuck.
+ SleepForAddressSpaceActivation bool
}
// InitKernelArgs holds arguments to Init.
@@ -260,8 +264,9 @@ type InitKernelArgs struct {
// RootUserNamespace is the root user namespace.
RootUserNamespace *auth.UserNamespace
- // NetworkStack is the TCP/IP network stack. NetworkStack may be nil.
- NetworkStack inet.Stack
+ // RootNetworkNamespace is the root network namespace. If nil, no networking
+ // will be available.
+ RootNetworkNamespace *inet.Namespace
// ApplicationCores is the number of logical CPUs visible to sandboxed
// applications. The set of logical CPU IDs is [0, ApplicationCores); thus
@@ -320,7 +325,10 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.rootUTSNamespace = args.RootUTSNamespace
k.rootIPCNamespace = args.RootIPCNamespace
k.rootAbstractSocketNamespace = args.RootAbstractSocketNamespace
- k.networkStack = args.NetworkStack
+ k.rootNetworkNamespace = args.RootNetworkNamespace
+ if k.rootNetworkNamespace == nil {
+ k.rootNetworkNamespace = inet.NewRootNamespace(nil, nil)
+ }
k.applicationCores = args.ApplicationCores
if args.UseHostCores {
k.useHostCores = true
@@ -543,8 +551,6 @@ func (ts *TaskSet) unregisterEpollWaiters() {
func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) error {
loadStart := time.Now()
- k.networkStack = net
-
initAppCores := k.applicationCores
// Load the pre-saved CPUID FeatureSet.
@@ -575,6 +581,10 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
log.Infof("Kernel load stats: %s", &stats)
log.Infof("Kernel load took [%s].", time.Since(kernelStart))
+ // rootNetworkNamespace should be populated after loading the state file.
+ // Restore the root network stack.
+ k.rootNetworkNamespace.RestoreRootStack(net)
+
// Load the memory file's state.
memoryStart := time.Now()
if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
@@ -905,6 +915,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID,
FSContext: fsContext,
FDTable: args.FDTable,
Credentials: args.Credentials,
+ NetworkNamespace: k.RootNetworkNamespace(),
AllowedCPUMask: sched.NewFullCPUSet(k.applicationCores),
UTSNamespace: args.UTSNamespace,
IPCNamespace: args.IPCNamespace,
@@ -1255,10 +1266,9 @@ func (k *Kernel) RootAbstractSocketNamespace() *AbstractSocketNamespace {
return k.rootAbstractSocketNamespace
}
-// NetworkStack returns the network stack. NetworkStack may return nil if no
-// network stack is available.
-func (k *Kernel) NetworkStack() inet.Stack {
- return k.networkStack
+// RootNetworkNamespace returns the root network namespace, always non-nil.
+func (k *Kernel) RootNetworkNamespace() *inet.Namespace {
+ return k.rootNetworkNamespace
}
// GlobalInit returns the thread group with ID 1 in the root PID namespace, or
diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go
index 18416643b..ded95f532 100644
--- a/pkg/sentry/kernel/rseq.go
+++ b/pkg/sentry/kernel/rseq.go
@@ -304,7 +304,7 @@ func (t *Task) rseqAddrInterrupt() {
}
var cs linux.RSeqCriticalSection
- if _, err := cs.CopyIn(t, critAddr); err != nil {
+ if err := cs.CopyIn(t, critAddr); err != nil {
t.Debugf("Failed to copy critical section from %#x for rseq: %v", critAddr, err)
t.forceSignal(linux.SIGSEGV, false /* unconditional */)
t.SendSignal(SignalInfoPriv(linux.SIGSEGV))
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index a3443ff21..2cee2e6ed 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -486,13 +486,10 @@ type Task struct {
numaPolicy int32
numaNodeMask uint64
- // If netns is true, the task is in a non-root network namespace. Network
- // namespaces aren't currently implemented in full; being in a network
- // namespace simply prevents the task from observing any network devices
- // (including loopback) or using abstract socket addresses (see unix(7)).
+ // netns is the task's network namespace. netns is never nil.
//
- // netns is protected by mu. netns is owned by the task goroutine.
- netns bool
+ // netns is protected by mu.
+ netns *inet.Namespace
// If rseqPreempted is true, before the next call to p.Switch(),
// interrupt rseq critical regions as defined by rseqAddr and
@@ -792,6 +789,15 @@ func (t *Task) NewFDFrom(fd int32, file *fs.File, flags FDFlags) (int32, error)
return fds[0], nil
}
+// NewFDFromVFS2 is a convenience wrapper for t.FDTable().NewFDVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.Get.
+func (t *Task) NewFDFromVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ return t.fdTable.NewFDVFS2(t, fd, file, flags)
+}
+
// NewFDAt is a convenience wrapper for t.FDTable().NewFDAt.
//
// This automatically passes the task as the context.
@@ -801,6 +807,15 @@ func (t *Task) NewFDAt(fd int32, file *fs.File, flags FDFlags) error {
return t.fdTable.NewFDAt(t, fd, file, flags)
}
+// NewFDAtVFS2 is a convenience wrapper for t.FDTable().NewFDAtVFS2.
+//
+// This automatically passes the task as the context.
+//
+// Precondition: same as FDTable.
+func (t *Task) NewFDAtVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) error {
+ return t.fdTable.NewFDAtVFS2(t, fd, file, flags)
+}
+
// WithMuLocked executes f with t.mu locked.
func (t *Task) WithMuLocked(f func(*Task)) {
t.mu.Lock()
diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go
index ba74b4c1c..78866f280 100644
--- a/pkg/sentry/kernel/task_clone.go
+++ b/pkg/sentry/kernel/task_clone.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -54,8 +55,7 @@ type SharingOptions struct {
NewUserNamespace bool
// If NewNetworkNamespace is true, the task should have an independent
- // network namespace. (Note that network namespaces are not really
- // implemented; see comment on Task.netns for details.)
+ // network namespace.
NewNetworkNamespace bool
// If NewFiles is true, the task should use an independent file descriptor
@@ -199,6 +199,11 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
ipcns = NewIPCNamespace(userns)
}
+ netns := t.NetworkNamespace()
+ if opts.NewNetworkNamespace {
+ netns = inet.NewNamespace(netns)
+ }
+
// TODO(b/63601033): Implement CLONE_NEWNS.
mntnsVFS2 := t.mountNamespaceVFS2
if mntnsVFS2 != nil {
@@ -268,7 +273,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
FDTable: fdTable,
Credentials: creds,
Niceness: t.Niceness(),
- NetworkNamespaced: t.netns,
+ NetworkNamespace: netns,
AllowedCPUMask: t.CPUMask(),
UTSNamespace: utsns,
IPCNamespace: ipcns,
@@ -283,9 +288,6 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) {
} else {
cfg.InheritParent = t
}
- if opts.NewNetworkNamespace {
- cfg.NetworkNamespaced = true
- }
nt, err := t.tg.pidns.owner.NewTask(cfg)
if err != nil {
if opts.NewThreadGroup {
@@ -482,7 +484,7 @@ func (t *Task) Unshare(opts *SharingOptions) error {
t.mu.Unlock()
return syserror.EPERM
}
- t.netns = true
+ t.netns = inet.NewNamespace(t.netns)
}
if opts.NewUTSNamespace {
if !haveCapSysAdmin {
diff --git a/pkg/sentry/kernel/task_context.go b/pkg/sentry/kernel/task_context.go
index 2be982684..0158b1788 100644
--- a/pkg/sentry/kernel/task_context.go
+++ b/pkg/sentry/kernel/task_context.go
@@ -140,7 +140,7 @@ func (k *Kernel) LoadTaskImage(ctx context.Context, args loader.LoadArgs) (*Task
}
// Prepare a new user address space to load into.
- m := mm.NewMemoryManager(k, k)
+ m := mm.NewMemoryManager(k, k, k.SleepForAddressSpaceActivation)
defer m.DecUsers(ctx)
args.MemoryManager = m
diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go
index 8f57a34a6..00c425cca 100644
--- a/pkg/sentry/kernel/task_exec.go
+++ b/pkg/sentry/kernel/task_exec.go
@@ -220,7 +220,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState {
t.mu.Unlock()
t.unstopVforkParent()
// NOTE(b/30316266): All locks must be dropped prior to calling Activate.
- t.MemoryManager().Activate()
+ t.MemoryManager().Activate(t)
t.ptraceExec(oldTID)
return (*runSyscallExit)(nil)
diff --git a/pkg/sentry/kernel/task_log.go b/pkg/sentry/kernel/task_log.go
index 6d737d3e5..eeccaa197 100644
--- a/pkg/sentry/kernel/task_log.go
+++ b/pkg/sentry/kernel/task_log.go
@@ -32,21 +32,21 @@ const (
// Infof logs an formatted info message by calling log.Infof.
func (t *Task) Infof(fmt string, v ...interface{}) {
if log.IsLogging(log.Info) {
- log.Infof(t.logPrefix.Load().(string)+fmt, v...)
+ log.InfofAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Warningf logs a warning string by calling log.Warningf.
func (t *Task) Warningf(fmt string, v ...interface{}) {
if log.IsLogging(log.Warning) {
- log.Warningf(t.logPrefix.Load().(string)+fmt, v...)
+ log.WarningfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
// Debugf creates a debug string that includes the task ID.
func (t *Task) Debugf(fmt string, v ...interface{}) {
if log.IsLogging(log.Debug) {
- log.Debugf(t.logPrefix.Load().(string)+fmt, v...)
+ log.DebugfAtDepth(1, t.logPrefix.Load().(string)+fmt, v...)
}
}
diff --git a/pkg/sentry/kernel/task_net.go b/pkg/sentry/kernel/task_net.go
index 172a31e1d..f7711232c 100644
--- a/pkg/sentry/kernel/task_net.go
+++ b/pkg/sentry/kernel/task_net.go
@@ -22,14 +22,23 @@ import (
func (t *Task) IsNetworkNamespaced() bool {
t.mu.Lock()
defer t.mu.Unlock()
- return t.netns
+ return !t.netns.IsRoot()
}
// NetworkContext returns the network stack used by the task. NetworkContext
// may return nil if no network stack is available.
+//
+// TODO(gvisor.dev/issue/1833): Migrate callers of this method to
+// NetworkNamespace().
func (t *Task) NetworkContext() inet.Stack {
- if t.IsNetworkNamespaced() {
- return nil
- }
- return t.k.networkStack
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns.Stack()
+}
+
+// NetworkNamespace returns the network namespace observed by the task.
+func (t *Task) NetworkNamespace() *inet.Namespace {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ return t.netns
}
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index f9236a842..a5035bb7f 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -17,6 +17,7 @@ package kernel
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/futex"
"gvisor.dev/gvisor/pkg/sentry/kernel/sched"
@@ -65,9 +66,8 @@ type TaskConfig struct {
// Niceness is the niceness of the new task.
Niceness int
- // If NetworkNamespaced is true, the new task should observe a non-root
- // network namespace.
- NetworkNamespaced bool
+ // NetworkNamespace is the network namespace to be used for the new task.
+ NetworkNamespace *inet.Namespace
// AllowedCPUMask contains the cpus that this task can run on.
AllowedCPUMask sched.CPUSet
@@ -133,7 +133,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
allowedCPUMask: cfg.AllowedCPUMask.Copy(),
ioUsage: &usage.IO{},
niceness: cfg.Niceness,
- netns: cfg.NetworkNamespaced,
+ netns: cfg.NetworkNamespace,
utsns: cfg.UTSNamespace,
ipcns: cfg.IPCNamespace,
abstractSockets: cfg.AbstractSocketNamespace,
diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go
index 2bf3ce8a8..b02044ad2 100644
--- a/pkg/sentry/kernel/task_usermem.go
+++ b/pkg/sentry/kernel/task_usermem.go
@@ -30,7 +30,7 @@ var MAX_RW_COUNT = int(usermem.Addr(math.MaxInt32).RoundDown())
// Activate ensures that the task has an active address space.
func (t *Task) Activate() {
if mm := t.MemoryManager(); mm != nil {
- if err := mm.Activate(); err != nil {
+ if err := mm.Activate(t); err != nil {
panic("unable to activate mm: " + err.Error())
}
}
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index 94d39af60..0332fc71c 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -18,6 +18,7 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -38,7 +39,7 @@ func (mm *MemoryManager) AddressSpace() platform.AddressSpace {
//
// When this MemoryManager is no longer needed by a task, it should call
// Deactivate to release the reference.
-func (mm *MemoryManager) Activate() error {
+func (mm *MemoryManager) Activate(ctx context.Context) error {
// Fast path: the MemoryManager already has an active
// platform.AddressSpace, and we just need to indicate that we need it too.
for {
@@ -91,16 +92,20 @@ func (mm *MemoryManager) Activate() error {
if as == nil {
// AddressSpace is unavailable, we must wait.
//
- // activeMu must not be held while waiting, as the user
- // of the address space we are waiting on may attempt
- // to take activeMu.
- //
- // Don't call UninterruptibleSleepStart to register the
- // wait to allow the watchdog stuck task to trigger in
- // case a process is starved waiting for the address
- // space.
+ // activeMu must not be held while waiting, as the user of the address
+ // space we are waiting on may attempt to take activeMu.
mm.activeMu.Unlock()
+
+ sleep := mm.p.CooperativelySchedulesAddressSpace() && mm.sleepForActivation
+ if sleep {
+ // Mark this task sleeping while waiting for the address space to
+ // prevent the watchdog from reporting it as a stuck task.
+ ctx.UninterruptibleSleepStart(false)
+ }
<-c
+ if sleep {
+ ctx.UninterruptibleSleepFinish(false)
+ }
continue
}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 3c263ebaa..d8a5b9d29 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -28,16 +28,17 @@ import (
)
// NewMemoryManager returns a new MemoryManager with no mappings and 1 user.
-func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider) *MemoryManager {
+func NewMemoryManager(p platform.Platform, mfp pgalloc.MemoryFileProvider, sleepForActivation bool) *MemoryManager {
return &MemoryManager{
- p: p,
- mfp: mfp,
- haveASIO: p.SupportsAddressSpaceIO(),
- privateRefs: &privateRefs{},
- users: 1,
- auxv: arch.Auxv{},
- dumpability: UserDumpable,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ p: p,
+ mfp: mfp,
+ haveASIO: p.SupportsAddressSpaceIO(),
+ privateRefs: &privateRefs{},
+ users: 1,
+ auxv: arch.Auxv{},
+ dumpability: UserDumpable,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: sleepForActivation,
}
}
@@ -79,9 +80,10 @@ func (mm *MemoryManager) Fork(ctx context.Context) (*MemoryManager, error) {
envv: mm.envv,
auxv: append(arch.Auxv(nil), mm.auxv...),
// IncRef'd below, once we know that there isn't an error.
- executable: mm.executable,
- dumpability: mm.dumpability,
- aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ executable: mm.executable,
+ dumpability: mm.dumpability,
+ aioManager: aioManager{contexts: make(map[uint64]*AIOContext)},
+ sleepForActivation: mm.sleepForActivation,
}
// Copy vmas.
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 637383c7a..c2195ae11 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -226,6 +226,11 @@ type MemoryManager struct {
// aioManager keeps track of AIOContexts used for async IOs. AIOManager
// must be cloned when CLONE_VM is used.
aioManager aioManager
+
+ // sleepForActivation indicates whether the task should report to be sleeping
+ // before trying to activate the address space. When set to true, delays in
+ // activation are not reported as stuck tasks by the watchdog.
+ sleepForActivation bool
}
// vma represents a virtual memory area.
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index edacca741..fdc308542 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -31,7 +31,7 @@ import (
func testMemoryManager(ctx context.Context) *MemoryManager {
p := platform.FromContext(ctx)
mfp := pgalloc.MemoryFileProviderFromContext(ctx)
- mm := NewMemoryManager(p, mfp)
+ mm := NewMemoryManager(p, mfp, false)
mm.layout = arch.MmapLayout{
MinAddr: p.MinUserAddress(),
MaxAddr: p.MaxUserAddress(),
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 8076c7529..f1afc74dc 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -329,10 +329,12 @@ func (m *machine) Destroy() {
}
// Get gets an available vCPU.
+//
+// This will return with the OS thread locked.
func (m *machine) Get() *vCPU {
+ m.mu.RLock()
runtime.LockOSThread()
tid := procid.Current()
- m.mu.RLock()
// Check for an exact match.
if c := m.vCPUs[tid]; c != nil {
@@ -343,8 +345,22 @@ func (m *machine) Get() *vCPU {
// The happy path failed. We now proceed to acquire an exclusive lock
// (because the vCPU map may change), and scan all available vCPUs.
+ // In this case, we first unlock the OS thread. Otherwise, if mu is
+ // not available, the current system thread will be parked and a new
+ // system thread spawned. We avoid this situation by simply refreshing
+ // tid after relocking the system thread.
m.mu.RUnlock()
+ runtime.UnlockOSThread()
m.mu.Lock()
+ runtime.LockOSThread()
+ tid = procid.Current()
+
+ // Recheck for an exact match.
+ if c := m.vCPUs[tid]; c != nil {
+ c.lock()
+ m.mu.Unlock()
+ return c
+ }
for {
// Scan for an available vCPU.
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index baa6c4910..d42eda37b 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -25,10 +25,14 @@
// not available for calls.
//
+// ERET returns using the ELR and SPSR for the current exception level.
#define ERET() \
WORD $0xd69f03e0
+// RSV_REG is a register that holds el1 information temporarily.
#define RSV_REG R18_PLATFORM
+
+// RSV_REG_APP is a register that holds el0 information temporarily.
#define RSV_REG_APP R9
#define FPEN_NOTRAP 0x3
@@ -36,6 +40,12 @@
#define FPEN_ENABLE (FPEN_NOTRAP << FPEN_SHIFT)
+// Saves a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not saved: R9, R18.
#define REGISTERS_SAVE(reg, offset) \
MOVD R0, offset+PTRACE_R0(reg); \
MOVD R1, offset+PTRACE_R1(reg); \
@@ -67,6 +77,12 @@
MOVD R29, offset+PTRACE_R29(reg); \
MOVD R30, offset+PTRACE_R30(reg);
+// Loads a register set.
+//
+// This is a macro because it may need to executed in contents where a stack is
+// not available for calls.
+//
+// The following registers are not loaded: R9, R18.
#define REGISTERS_LOAD(reg, offset) \
MOVD offset+PTRACE_R0(reg), R0; \
MOVD offset+PTRACE_R1(reg), R1; \
@@ -98,7 +114,7 @@
MOVD offset+PTRACE_R29(reg), R29; \
MOVD offset+PTRACE_R30(reg), R30;
-//NOP
+// NOP-s
#define nop31Instructions() \
WORD $0xd503201f; \
WORD $0xd503201f; \
@@ -254,6 +270,7 @@
#define ESR_ELx_WFx_ISS_WFE (UL(1) << 0)
#define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1)
+// LOAD_KERNEL_ADDRESS loads a kernel address.
#define LOAD_KERNEL_ADDRESS(from, to) \
MOVD from, to; \
ORR $0xffff000000000000, to, to;
@@ -263,15 +280,18 @@
LOAD_KERNEL_ADDRESS(CPU_SELF(from), RSV_REG); \
MOVD $CPU_STACK_TOP(RSV_REG), RSV_REG; \
MOVD RSV_REG, RSP; \
+ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
ISB $15; \
DSB $15;
+// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
#define SWITCH_TO_APP_PAGETABLE(from) \
MOVD CPU_TTBR0_APP(from), RSV_REG; \
WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
ISB $15; \
DSB $15;
+// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
#define SWITCH_TO_KVM_PAGETABLE(from) \
MOVD CPU_TTBR0_KVM(from), RSV_REG; \
WORD $0xd5182012; \ // MSR R18, TTBR0_EL1
@@ -294,6 +314,7 @@
WORD $0xd5181040; \ //MSR R0, CPACR_EL1
ISB $15;
+// KERNEL_ENTRY_FROM_EL0 is the entry code of the vcpu from el0 to el1.
#define KERNEL_ENTRY_FROM_EL0 \
SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack.
STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \
@@ -315,19 +336,22 @@
WORD $0xd5384103; \ // MRS SP_EL0, R3
MOVD R3, PTRACE_SP(RSV_REG_APP);
+// KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1.
#define KERNEL_ENTRY_FROM_EL1 \
WORD $0xd538d092; \ //MRS TPIDR_EL1, R18
- REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // save sentry context
+ REGISTERS_SAVE(RSV_REG, CPU_REGISTERS); \ // Save sentry context.
MOVD RSV_REG_APP, CPU_REGISTERS+PTRACE_R9(RSV_REG); \
WORD $0xd5384004; \ // MRS SPSR_EL1, R4
MOVD R4, CPU_REGISTERS+PTRACE_PSTATE(RSV_REG); \
MRS ELR_EL1, R4; \
MOVD R4, CPU_REGISTERS+PTRACE_PC(RSV_REG); \
MOVD RSP, R4; \
- MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG);
+ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \
+ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack.
+// Halt halts execution.
TEXT ·Halt(SB),NOSPLIT,$0
- // clear bluepill.
+ // Clear bluepill.
WORD $0xd538d092 //MRS TPIDR_EL1, R18
CMP RSV_REG, R9
BNE mmio_exit
@@ -341,8 +365,22 @@ mmio_exit:
// MMIO_EXIT.
MOVD $0, R9
MOVD R0, 0xffff000000001000(R9)
- B ·kernelExitToEl1(SB)
+ RET
+
+// HaltAndResume halts execution and point the pointer to the resume function.
+TEXT ·HaltAndResume(SB),NOSPLIT,$0
+ BL ·Halt(SB)
+ B ·kernelExitToEl1(SB) // Resume.
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume.
+TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0
+ WORD $0xd538d092 // MRS TPIDR_EL1, R18
+ MOVD CPU_SELF(RSV_REG), R3 // Load vCPU.
+ MOVD R3, 8(RSP) // First argument (vCPU).
+ CALL ·kernelSyscall(SB) // Call the trampoline.
+ B ·kernelExitToEl1(SB) // Resume.
+
+// Shutdown stops the guest.
TEXT ·Shutdown(SB),NOSPLIT,$0
// PSCI EVENT.
MOVD $0x84000009, R0
@@ -429,6 +467,7 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0
TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
ERET()
+// Start is the CPU entrypoint.
TEXT ·Start(SB),NOSPLIT,$0
IRQ_DISABLE
MOVD R8, RSV_REG
@@ -437,18 +476,23 @@ TEXT ·Start(SB),NOSPLIT,$0
B ·kernelExitToEl1(SB)
+// El1_sync_invalid is the handler for an invalid EL1_sync.
TEXT ·El1_sync_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_irq_invalid is the handler for an invalid El1_irq.
TEXT ·El1_irq_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_fiq_invalid is the handler for an invalid El1_fiq.
TEXT ·El1_fiq_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_error_invalid is the handler for an invalid El1_error.
TEXT ·El1_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_sync is the handler for El1_sync.
TEXT ·El1_sync(SB),NOSPLIT,$0
KERNEL_ENTRY_FROM_EL1
WORD $0xd5385219 // MRS ESR_EL1, R25
@@ -484,10 +528,10 @@ el1_da:
MOVD $PageFault, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el1_ia:
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el1_sp_pc:
B ·Shutdown(SB)
@@ -496,7 +540,9 @@ el1_undef:
B ·Shutdown(SB)
el1_svc:
- B ·Halt(SB)
+ MOVD $0, CPU_ERROR_CODE(RSV_REG)
+ MOVD $0, CPU_ERROR_TYPE(RSV_REG)
+ B ·HaltEl1SvcAndResume(SB)
el1_dbg:
B ·Shutdown(SB)
@@ -508,15 +554,19 @@ el1_fpsimd_acc:
el1_invalid:
B ·Shutdown(SB)
+// El1_irq is the handler for El1_irq.
TEXT ·El1_irq(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_fiq is the handler for El1_fiq.
TEXT ·El1_fiq(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El1_error is the handler for El1_error.
TEXT ·El1_error(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// El0_sync is the handler for El0_sync.
TEXT ·El0_sync(SB),NOSPLIT,$0
KERNEL_ENTRY_FROM_EL0
WORD $0xd5385219 // MRS ESR_EL1, R25
@@ -554,7 +604,7 @@ el0_svc:
MOVD $Syscall, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el0_da:
WORD $0xd538d092 //MRS TPIDR_EL1, R18
@@ -568,7 +618,7 @@ el0_da:
MOVD $PageFault, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
el0_ia:
B ·Shutdown(SB)
@@ -613,7 +663,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0
MOVD $VirtualizationException, R3
MOVD R3, CPU_VECTOR_CODE(RSV_REG)
- B ·Halt(SB)
+ B ·HaltAndResume(SB)
TEXT ·El0_sync_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
@@ -627,6 +677,7 @@ TEXT ·El0_fiq_invalid(SB),NOSPLIT,$0
TEXT ·El0_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
+// Vectors implements exception vector table.
TEXT ·Vectors(SB),NOSPLIT,$0
B ·El1_sync_invalid(SB)
nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index c3d341998..ccacaea6b 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -16,6 +16,14 @@
package ring0
+// HaltAndResume halts execution and point the pointer to the resume function.
+//go:nosplit
+func HaltAndResume()
+
+// HaltEl1SvcAndResume calls Hooks.KernelSyscall and resume.
+//go:nosplit
+func HaltEl1SvcAndResume()
+
// init initializes architecture-specific state.
func (k *Kernel) init(opts KernelOpts) {
// Save the root page tables.
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 2f39a6f2b..88d5db9fc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"capability.go",
"clone.go",
+ "epoll.go",
"futex.go",
"linux64_amd64.go",
"linux64_arm64.go",
diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go
new file mode 100644
index 000000000..a6e48b836
--- /dev/null
+++ b/pkg/sentry/strace/epoll.go
@@ -0,0 +1,89 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package strace
+
+import (
+ "fmt"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(eventAddr, &e); err != nil {
+ return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err)
+ }
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x ", eventAddr)
+ writeEpollEvent(&sb, e)
+ return sb.String()
+}
+
+func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes uint64) string {
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "%#x {", eventsAddr)
+ addr := eventsAddr
+ for i := uint64(0); i < numEvents; i++ {
+ var e linux.EpollEvent
+ if _, err := t.CopyIn(addr, &e); err != nil {
+ fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err)
+ continue
+ }
+ writeEpollEvent(&sb, e)
+ if uint64(sb.Len()) >= maxBytes {
+ sb.WriteString("...")
+ break
+ }
+ if _, ok := addr.AddLength(uint64(linux.SizeOfEpollEvent)); !ok {
+ fmt.Fprintf(&sb, "{error reading event at %#x: EFAULT}", addr)
+ continue
+ }
+ }
+ sb.WriteString("}")
+ return sb.String()
+}
+
+func writeEpollEvent(sb *strings.Builder, e linux.EpollEvent) {
+ events := epollEventEvents.Parse(uint64(e.Events))
+ fmt.Fprintf(sb, "{events=%s data=[%#x, %#x]}", events, e.Data[0], e.Data[1])
+}
+
+var epollCtlOps = abi.ValueSet{
+ linux.EPOLL_CTL_ADD: "EPOLL_CTL_ADD",
+ linux.EPOLL_CTL_DEL: "EPOLL_CTL_DEL",
+ linux.EPOLL_CTL_MOD: "EPOLL_CTL_MOD",
+}
+
+var epollEventEvents = abi.FlagSet{
+ {Flag: linux.EPOLLIN, Name: "EPOLLIN"},
+ {Flag: linux.EPOLLPRI, Name: "EPOLLPRI"},
+ {Flag: linux.EPOLLOUT, Name: "EPOLLOUT"},
+ {Flag: linux.EPOLLERR, Name: "EPOLLERR"},
+ {Flag: linux.EPOLLHUP, Name: "EPULLHUP"},
+ {Flag: linux.EPOLLRDNORM, Name: "EPOLLRDNORM"},
+ {Flag: linux.EPOLLRDBAND, Name: "EPOLLRDBAND"},
+ {Flag: linux.EPOLLWRNORM, Name: "EPOLLWRNORM"},
+ {Flag: linux.EPOLLWRBAND, Name: "EPOLLWRBAND"},
+ {Flag: linux.EPOLLMSG, Name: "EPOLLMSG"},
+ {Flag: linux.EPOLLRDHUP, Name: "EPOLLRDHUP"},
+ {Flag: linux.EPOLLEXCLUSIVE, Name: "EPOLLEXCLUSIVE"},
+ {Flag: linux.EPOLLWAKEUP, Name: "EPOLLWAKEUP"},
+ {Flag: linux.EPOLLONESHOT, Name: "EPOLLONESHOT"},
+ {Flag: linux.EPOLLET, Name: "EPOLLET"},
+}
diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go
index a4de545e9..71b92eaee 100644
--- a/pkg/sentry/strace/linux64_amd64.go
+++ b/pkg/sentry/strace/linux64_amd64.go
@@ -256,8 +256,8 @@ var linuxAMD64 = SyscallMap{
229: makeSyscallInfo("clock_getres", Hex, PostTimespec),
230: makeSyscallInfo("clock_nanosleep", Hex, Hex, Timespec, PostTimespec),
231: makeSyscallInfo("exit_group", Hex),
- 232: makeSyscallInfo("epoll_wait", Hex, Hex, Hex, Hex),
- 233: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
+ 232: makeSyscallInfo("epoll_wait", FD, EpollEvents, Hex, Hex),
+ 233: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
234: makeSyscallInfo("tgkill", Hex, Hex, Signal),
235: makeSyscallInfo("utimes", Path, Timeval),
// 236: vserver (not implemented in the Linux kernel)
@@ -305,7 +305,7 @@ var linuxAMD64 = SyscallMap{
278: makeSyscallInfo("vmsplice", FD, Hex, Hex, Hex),
279: makeSyscallInfo("move_pages", Hex, Hex, Hex, Hex, Hex, Hex),
280: makeSyscallInfo("utimensat", FD, Path, UTimeTimespec, Hex),
- 281: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 281: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
282: makeSyscallInfo("signalfd", Hex, Hex, Hex),
283: makeSyscallInfo("timerfd_create", Hex, Hex),
284: makeSyscallInfo("eventfd", Hex),
diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go
index 8bc38545f..bd7361a52 100644
--- a/pkg/sentry/strace/linux64_arm64.go
+++ b/pkg/sentry/strace/linux64_arm64.go
@@ -45,8 +45,8 @@ var linuxARM64 = SyscallMap{
18: makeSyscallInfo("lookup_dcookie", Hex, Hex, Hex),
19: makeSyscallInfo("eventfd2", Hex, Hex),
20: makeSyscallInfo("epoll_create1", Hex),
- 21: makeSyscallInfo("epoll_ctl", Hex, Hex, FD, Hex),
- 22: makeSyscallInfo("epoll_pwait", Hex, Hex, Hex, Hex, SigSet, Hex),
+ 21: makeSyscallInfo("epoll_ctl", FD, EpollCtlOp, FD, EpollEvent),
+ 22: makeSyscallInfo("epoll_pwait", FD, EpollEvents, Hex, Hex, SigSet, Hex),
23: makeSyscallInfo("dup", FD),
24: makeSyscallInfo("dup3", FD, FD, Hex),
25: makeSyscallInfo("fcntl", FD, Hex, Hex),
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index 51e6d81b2..c0512de89 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -632,4 +632,13 @@ var sockOptNames = map[uint64]abi.ValueSet{
linux.MCAST_MSFILTER: "MCAST_MSFILTER",
linux.IPV6_ADDRFORM: "IPV6_ADDRFORM",
},
+ linux.SOL_NETLINK: {
+ linux.NETLINK_BROADCAST_ERROR: "NETLINK_BROADCAST_ERROR",
+ linux.NETLINK_CAP_ACK: "NETLINK_CAP_ACK",
+ linux.NETLINK_DUMP_STRICT_CHK: "NETLINK_DUMP_STRICT_CHK",
+ linux.NETLINK_EXT_ACK: "NETLINK_EXT_ACK",
+ linux.NETLINK_LIST_MEMBERSHIPS: "NETLINK_LIST_MEMBERSHIPS",
+ linux.NETLINK_NO_ENOBUFS: "NETLINK_NO_ENOBUFS",
+ linux.NETLINK_PKTINFO: "NETLINK_PKTINFO",
+ },
}
diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go
index 46cb2a1cc..77655558e 100644
--- a/pkg/sentry/strace/strace.go
+++ b/pkg/sentry/strace/strace.go
@@ -481,6 +481,12 @@ func (i *SyscallInfo) pre(t *kernel.Task, args arch.SyscallArguments, maximumBlo
output = append(output, capData(t, args[arg-1].Pointer(), args[arg].Pointer()))
case PollFDs:
output = append(output, pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), false))
+ case EpollCtlOp:
+ output = append(output, epollCtlOps.Parse(uint64(args[arg].Int())))
+ case EpollEvent:
+ output = append(output, epollEvent(t, args[arg].Pointer()))
+ case EpollEvents:
+ output = append(output, epollEvents(t, args[arg].Pointer(), 0 /* numEvents */, uint64(maximumBlobSize)))
case SelectFDSet:
output = append(output, fdSet(t, int(args[0].Int()), args[arg].Pointer()))
case Oct:
@@ -549,6 +555,8 @@ func (i *SyscallInfo) post(t *kernel.Task, args arch.SyscallArguments, rval uint
output[arg] = capData(t, args[arg-1].Pointer(), args[arg].Pointer())
case PollFDs:
output[arg] = pollFDs(t, args[arg].Pointer(), uint(args[arg+1].Uint()), true)
+ case EpollEvents:
+ output[arg] = epollEvents(t, args[arg].Pointer(), uint64(rval), uint64(maximumBlobSize))
case GetSockOptVal:
output[arg] = getSockOptVal(t, args[arg-2].Uint64() /* level */, args[arg-1].Uint64() /* optName */, args[arg].Pointer() /* optVal */, args[arg+1].Pointer() /* optLen */, maximumBlobSize, rval)
case SetSockOptVal:
diff --git a/pkg/sentry/strace/syscalls.go b/pkg/sentry/strace/syscalls.go
index 446d1e0f6..7e69b9279 100644
--- a/pkg/sentry/strace/syscalls.go
+++ b/pkg/sentry/strace/syscalls.go
@@ -228,6 +228,16 @@ const (
// SockOptLevel is the optname argument in getsockopt(2) and
// setsockopt(2).
SockOptName
+
+ // EpollCtlOp is the op argument to epoll_ctl(2).
+ EpollCtlOp
+
+ // EpollEvent is the event argument in epoll_ctl(2).
+ EpollEvent
+
+ // EpollEvents is an array of struct epoll_event. It is the events
+ // argument in epoll_wait(2)/epoll_pwait(2).
+ EpollEvents
)
// defaultFormat is the syscall argument format to use if the actual format is
diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go
index fbef5b376..3ab93fbde 100644
--- a/pkg/sentry/syscalls/linux/sys_epoll.go
+++ b/pkg/sentry/syscalls/linux/sys_epoll.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
// EpollCreate1 implements the epoll_create1(2) linux syscall.
func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
flags := args[0].Int()
@@ -164,3 +166,5 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return EpollWait(t, args)
}
+
+// LINT.ThenChange(vfs2/epoll.go)
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 421845ebb..d10a9bed8 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -130,6 +130,8 @@ func copyInPath(t *kernel.Task, addr usermem.Addr, allowEmpty bool) (path string
return path, dirPath, nil
}
+// LINT.IfChange
+
func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -575,6 +577,10 @@ func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, accessAt(t, dirFD, addr, flags&linux.AT_SYMLINK_NOFOLLOW == 0, mode)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Ioctl implements linux syscall ioctl(2).
func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -650,6 +656,10 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/ioctl.go)
+
+// LINT.IfChange
+
// Getcwd implements the linux syscall getcwd(2).
func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -760,6 +770,10 @@ func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/fscontext.go)
+
+// LINT.IfChange
+
// Close implements linux syscall close(2).
func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -1094,6 +1108,8 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
}
+// LINT.ThenChange(vfs2/fd.go)
+
const (
_FADV_NORMAL = 0
_FADV_RANDOM = 1
@@ -1141,6 +1157,8 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.IfChange
+
func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error {
path, _, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1218,7 +1236,7 @@ func rmdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return syserror.ENOTEMPTY
}
- if err := fs.MayDelete(t, root, d, name); err != nil {
+ if err := d.MayDelete(t, root, name); err != nil {
return err
}
@@ -1421,6 +1439,10 @@ func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, linkAt(t, oldDirFD, oldAddr, newDirFD, newAddr, resolve, allowEmpty)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
func readlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr, bufAddr usermem.Addr, size uint) (copied uintptr, err error) {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1480,6 +1502,10 @@ func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return n, nil, err
}
+// LINT.ThenChange(vfs2/stat.go)
+
+// LINT.IfChange
+
func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
path, dirPath, err := copyInPath(t, addr, false /* allowEmpty */)
if err != nil {
@@ -1491,7 +1517,7 @@ func unlinkAt(t *kernel.Task, dirFD int32, addr usermem.Addr) error {
return syserror.ENOTDIR
}
- if err := fs.MayDelete(t, root, d, name); err != nil {
+ if err := d.MayDelete(t, root, name); err != nil {
return err
}
@@ -1516,6 +1542,10 @@ func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, unlinkAt(t, dirFD, addr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
+// LINT.IfChange
+
// Truncate implements linux syscall truncate(2).
func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
addr := args[0].Pointer()
@@ -1614,6 +1644,8 @@ func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, nil
}
+// LINT.ThenChange(vfs2/setstat.go)
+
// Umask implements linux syscall umask(2).
func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
mask := args[0].ModeT()
@@ -1621,6 +1653,8 @@ func Umask(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return uintptr(mask), nil, nil
}
+// LINT.IfChange
+
// Change ownership of a file.
//
// uid and gid may be -1, in which case they will not be changed.
@@ -1987,6 +2021,10 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys
return 0, nil, utimes(t, dirFD, pathnameAddr, ts, true)
}
+// LINT.ThenChange(vfs2/setstat.go)
+
+// LINT.IfChange
+
func renameAt(t *kernel.Task, oldDirFD int32, oldAddr usermem.Addr, newDirFD int32, newAddr usermem.Addr) error {
newPath, _, err := copyInPath(t, newAddr, false /* allowEmpty */)
if err != nil {
@@ -2042,6 +2080,8 @@ func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
return 0, nil, renameAt(t, oldDirFD, oldPathAddr, newDirFD, newPathAddr)
}
+// LINT.ThenChange(vfs2/filesystem.go)
+
// Fallocate implements linux system call fallocate(2).
func Fallocate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go
index f66f4ffde..b126fecc0 100644
--- a/pkg/sentry/syscalls/linux/sys_getdents.go
+++ b/pkg/sentry/syscalls/linux/sys_getdents.go
@@ -27,6 +27,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// Getdents implements linux syscall getdents(2) for 64bit systems.
func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -244,3 +246,5 @@ func (ds *direntSerializer) CopyOut(name string, attr fs.DentAttr) error {
func (ds *direntSerializer) Written() int {
return ds.written
}
+
+// LINT.ThenChange(vfs2/getdents.go)
diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go
index 297e920c4..3f7691eae 100644
--- a/pkg/sentry/syscalls/linux/sys_lseek.go
+++ b/pkg/sentry/syscalls/linux/sys_lseek.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Lseek implements linux syscall lseek(2).
func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
@@ -52,3 +54,5 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return uintptr(offset), nil, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go
index 9959f6e61..91694d374 100644
--- a/pkg/sentry/syscalls/linux/sys_mmap.go
+++ b/pkg/sentry/syscalls/linux/sys_mmap.go
@@ -35,6 +35,8 @@ func Brk(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo
return uintptr(addr), nil, nil
}
+// LINT.IfChange
+
// Mmap implements linux syscall mmap(2).
func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
prot := args[2].Int()
@@ -104,6 +106,8 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(rv), nil, err
}
+// LINT.ThenChange(vfs2/mmap.go)
+
// Munmap implements linux syscall munmap(2).
func Munmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return 0, nil, t.MemoryManager().MUnmap(t, args[0].Pointer(), args[1].Uint64())
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 227692f06..78a2cb750 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskRead contains events that can be triggered on reads.
EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
@@ -388,3 +390,5 @@ func preadv(t *kernel.Task, f *fs.File, dst usermem.IOSequence, offset int64) (i
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go
index 8b66a9006..701b27b4a 100644
--- a/pkg/sentry/syscalls/linux/sys_stat.go
+++ b/pkg/sentry/syscalls/linux/sys_stat.go
@@ -23,6 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
func statFromAttrs(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr) linux.Stat {
return linux.Stat{
Dev: sattr.DeviceID,
@@ -131,8 +133,7 @@ func stat(t *kernel.Task, d *fs.Dirent, dirPath bool, statAddr usermem.Addr) err
return err
}
s := statFromAttrs(t, d.Inode.StableAttr, uattr)
- _, err = s.CopyOut(t, statAddr)
- return err
+ return s.CopyOut(t, statAddr)
}
// fstat implements fstat for the given *fs.File.
@@ -142,8 +143,7 @@ func fstat(t *kernel.Task, f *fs.File, statAddr usermem.Addr) error {
return err
}
s := statFromAttrs(t, f.Dirent.Inode.StableAttr, uattr)
- _, err = s.CopyOut(t, statAddr)
- return err
+ return s.CopyOut(t, statAddr)
}
// Statx implements linux syscall statx(2).
@@ -299,3 +299,5 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error {
_, err = t.CopyOut(addr, &statfs)
return err
}
+
+// LINT.ThenChange(vfs2/stat.go)
diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go
index 3e55235bd..5ad465ae3 100644
--- a/pkg/sentry/syscalls/linux/sys_sync.go
+++ b/pkg/sentry/syscalls/linux/sys_sync.go
@@ -22,6 +22,8 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
)
+// LINT.IfChange
+
// Sync implements linux system call sync(2).
func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
t.MountNamespace().SyncAll(t)
@@ -135,3 +137,5 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel
return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS)
}
+
+// LINT.ThenChange(vfs2/sync.go)
diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go
index aba892939..506ee54ce 100644
--- a/pkg/sentry/syscalls/linux/sys_write.go
+++ b/pkg/sentry/syscalls/linux/sys_write.go
@@ -28,6 +28,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// LINT.IfChange
+
const (
// EventMaskWrite contains events that can be triggered on writes.
//
@@ -358,3 +360,5 @@ func pwritev(t *kernel.Task, f *fs.File, src usermem.IOSequence, offset int64) (
return total, err
}
+
+// LINT.ThenChange(vfs2/read_write.go)
diff --git a/pkg/sentry/syscalls/linux/sys_xattr.go b/pkg/sentry/syscalls/linux/sys_xattr.go
index 9d8140b8a..2de5e3422 100644
--- a/pkg/sentry/syscalls/linux/sys_xattr.go
+++ b/pkg/sentry/syscalls/linux/sys_xattr.go
@@ -25,6 +25,8 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
+// LINT.IfChange
+
// GetXattr implements linux syscall getxattr(2).
func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
return getXattrFromPath(t, args, true)
@@ -418,3 +420,5 @@ func removeXattr(t *kernel.Task, d *fs.Dirent, nameAddr usermem.Addr) error {
return d.Inode.RemoveXattr(t, d, name)
}
+
+// LINT.ThenChange(vfs2/xattr.go)
diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD
index 6b8a00b6e..f51761e81 100644
--- a/pkg/sentry/syscalls/linux/vfs2/BUILD
+++ b/pkg/sentry/syscalls/linux/vfs2/BUILD
@@ -5,18 +5,44 @@ package(licenses = ["notice"])
go_library(
name = "vfs2",
srcs = [
+ "epoll.go",
+ "epoll_unsafe.go",
+ "execve.go",
+ "fd.go",
+ "filesystem.go",
+ "fscontext.go",
+ "getdents.go",
+ "ioctl.go",
"linux64.go",
"linux64_override_amd64.go",
"linux64_override_arm64.go",
- "sys_read.go",
+ "mmap.go",
+ "path.go",
+ "poll.go",
+ "read_write.go",
+ "setstat.go",
+ "stat.go",
+ "sync.go",
+ "xattr.go",
],
+ marshal = True,
visibility = ["//:sandbox"],
deps = [
+ "//pkg/abi/linux",
+ "//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/sentry/arch",
+ "//pkg/sentry/fsbridge",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/limits",
+ "//pkg/sentry/loader",
+ "//pkg/sentry/memmap",
"//pkg/sentry/syscalls",
"//pkg/sentry/syscalls/linux",
"//pkg/sentry/vfs",
+ "//pkg/sync",
"//pkg/syserror",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go
new file mode 100644
index 000000000..d6cb0e79a
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go
@@ -0,0 +1,225 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "math"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EpollCreate1 implements Linux syscall epoll_create1(2).
+func EpollCreate1(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ flags := args[0].Int()
+ if flags&^linux.EPOLL_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.EPOLL_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCreate implements Linux syscall epoll_create(2).
+func EpollCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ size := args[0].Int()
+
+ // "Since Linux 2.6.8, the size argument is ignored, but must be greater
+ // than zero" - epoll_create(2)
+ if size <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file, err := t.Kernel().VFS().NewEpollInstanceFD()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+}
+
+// EpollCtl implements Linux syscall epoll_ctl(2).
+func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ op := args[1].Int()
+ fd := args[2].Int()
+ eventAddr := args[3].Pointer()
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+ if epfile == file {
+ return 0, nil, syserror.EINVAL
+ }
+
+ var event linux.EpollEvent
+ switch op {
+ case linux.EPOLL_CTL_ADD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.AddInterest(file, fd, event)
+ case linux.EPOLL_CTL_DEL:
+ return 0, nil, ep.DeleteInterest(file, fd)
+ case linux.EPOLL_CTL_MOD:
+ if err := event.CopyIn(t, eventAddr); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, ep.ModifyInterest(file, fd, event)
+ default:
+ return 0, nil, syserror.EINVAL
+ }
+}
+
+// EpollWait implements Linux syscall epoll_wait(2).
+func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ epfd := args[0].Int()
+ eventsAddr := args[1].Pointer()
+ maxEvents := int(args[2].Int())
+ timeout := int(args[3].Int())
+
+ const _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS
+ if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS {
+ return 0, nil, syserror.EINVAL
+ }
+
+ epfile := t.GetFileVFS2(epfd)
+ if epfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer epfile.DecRef()
+ ep, ok := epfile.Impl().(*vfs.EpollInstance)
+ if !ok {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent,
+ // maxEvents), so that the buffer can be allocated on the stack.
+ var (
+ events [16]linux.EpollEvent
+ total int
+ ch chan struct{}
+ haveDeadline bool
+ deadline ktime.Time
+ )
+ for {
+ batchEvents := len(events)
+ if batchEvents > maxEvents {
+ batchEvents = maxEvents
+ }
+ n := ep.ReadEvents(events[:batchEvents])
+ maxEvents -= n
+ if n != 0 {
+ // Copy what we read out.
+ copiedEvents, err := copyOutEvents(t, eventsAddr, events[:n])
+ eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent)
+ total += copiedEvents
+ if err != nil {
+ if total != 0 {
+ return uintptr(total), nil, nil
+ }
+ return 0, nil, err
+ }
+ // If we've filled the application's event buffer, we're done.
+ if maxEvents == 0 {
+ return uintptr(total), nil, nil
+ }
+ // Loop if we read a full batch, under the expectation that there
+ // may be more events to read.
+ if n == batchEvents {
+ continue
+ }
+ }
+ // We get here if n != batchEvents. If we read any number of events
+ // (just now, or in a previous iteration of this loop), or if timeout
+ // is 0 (such that epoll_wait should be non-blocking), return the
+ // events we've read so far to the application.
+ if total != 0 || timeout == 0 {
+ return uintptr(total), nil, nil
+ }
+ // In the first iteration of this loop, register with the epoll
+ // instance for readability events, but then immediately continue the
+ // loop since we need to retry ReadEvents() before blocking. In all
+ // subsequent iterations, block until events are available, the timeout
+ // expires, or an interrupt arrives.
+ if ch == nil {
+ var w waiter.Entry
+ w, ch = waiter.NewChannelEntry(nil)
+ epfile.EventRegister(&w, waiter.EventIn)
+ defer epfile.EventUnregister(&w)
+ } else {
+ // Set up the timer if a timeout was specified.
+ if timeout > 0 && !haveDeadline {
+ timeoutDur := time.Duration(timeout) * time.Millisecond
+ deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur)
+ haveDeadline = true
+ }
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ // total must be 0 since otherwise we would have returned
+ // above.
+ return 0, nil, err
+ }
+ }
+ }
+}
+
+// EpollPwait implements Linux syscall epoll_pwait(2).
+func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ maskAddr := args[4].Pointer()
+ maskSize := uint(args[5].Uint())
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ return EpollWait(t, args)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
new file mode 100644
index 000000000..825f325bf
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/epoll_unsafe.go
@@ -0,0 +1,44 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "reflect"
+ "runtime"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const sizeofEpollEvent = int(unsafe.Sizeof(linux.EpollEvent{}))
+
+func copyOutEvents(t *kernel.Task, addr usermem.Addr, events []linux.EpollEvent) (int, error) {
+ if len(events) == 0 {
+ return 0, nil
+ }
+ // Cast events to a byte slice for copying.
+ var eventBytes []byte
+ eventBytesHdr := (*reflect.SliceHeader)(unsafe.Pointer(&eventBytes))
+ eventBytesHdr.Data = uintptr(gohacks.Noescape(unsafe.Pointer(&events[0])))
+ eventBytesHdr.Len = len(events) * sizeofEpollEvent
+ eventBytesHdr.Cap = len(events) * sizeofEpollEvent
+ copiedBytes, err := t.CopyOutBytes(addr, eventBytes)
+ runtime.KeepAlive(events)
+ copiedEvents := copiedBytes / sizeofEpollEvent // rounded down
+ return copiedEvents, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/execve.go b/pkg/sentry/syscalls/linux/vfs2/execve.go
new file mode 100644
index 000000000..aef0078a8
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/execve.go
@@ -0,0 +1,137 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsbridge"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/loader"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Execve implements linux syscall execve(2).
+func Execve(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathnameAddr := args[0].Pointer()
+ argvAddr := args[1].Pointer()
+ envvAddr := args[2].Pointer()
+ return execveat(t, linux.AT_FDCWD, pathnameAddr, argvAddr, envvAddr, 0 /* flags */)
+}
+
+// Execveat implements linux syscall execveat(2).
+func Execveat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathnameAddr := args[1].Pointer()
+ argvAddr := args[2].Pointer()
+ envvAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return execveat(t, dirfd, pathnameAddr, argvAddr, envvAddr, flags)
+}
+
+func execveat(t *kernel.Task, dirfd int32, pathnameAddr, argvAddr, envvAddr usermem.Addr, flags int32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ pathname, err := t.CopyInString(pathnameAddr, linux.PATH_MAX)
+ if err != nil {
+ return 0, nil, err
+ }
+ var argv, envv []string
+ if argvAddr != 0 {
+ var err error
+ argv, err = t.CopyInVector(argvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+ if envvAddr != 0 {
+ var err error
+ envv, err = t.CopyInVector(envvAddr, slinux.ExecMaxElemSize, slinux.ExecMaxTotalSize)
+ if err != nil {
+ return 0, nil, err
+ }
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ var executable fsbridge.File
+ closeOnExec := false
+ if path := fspath.Parse(pathname); dirfd != linux.AT_FDCWD && !path.Absolute {
+ // We must open the executable ourselves since dirfd is used as the
+ // starting point while resolving path, but the task working directory
+ // is used as the starting point while resolving interpreters (Linux:
+ // fs/binfmt_script.c:load_script() => fs/exec.c:open_exec() =>
+ // do_open_execat(fd=AT_FDCWD)), and the loader package is currently
+ // incapable of handling this correctly.
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ dirfile, dirfileFlags := t.FDTable().GetVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ start := dirfile.VirtualDentry()
+ start.IncRef()
+ dirfile.DecRef()
+ closeOnExec = dirfileFlags.CloseOnExec
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDONLY,
+ FileExec: true,
+ })
+ start.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+ executable = fsbridge.NewVFSFile(file)
+ }
+
+ // Load the new TaskContext.
+ mntns := t.MountNamespaceVFS2() // FIXME(jamieliu): useless refcount change
+ defer mntns.DecRef()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ defer wd.DecRef()
+ remainingTraversals := uint(linux.MaxSymlinkTraversals)
+ loadArgs := loader.LoadArgs{
+ Opener: fsbridge.NewVFSLookup(mntns, root, wd),
+ RemainingTraversals: &remainingTraversals,
+ ResolveFinal: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ Filename: pathname,
+ File: executable,
+ CloseOnExec: closeOnExec,
+ Argv: argv,
+ Envv: envv,
+ Features: t.Arch().FeatureSet(),
+ }
+
+ tc, se := t.Kernel().LoadTaskImage(t, loadArgs)
+ if se != nil {
+ return 0, nil, se.ToError()
+ }
+
+ ctrl, err := t.Execve(tc)
+ return 0, ctrl, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
new file mode 100644
index 000000000..3afcea665
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -0,0 +1,147 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Close implements Linux syscall close(2).
+func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ // Note that Remove provides a reference on the file that we may use to
+ // flush. It is still active until we drop the final reference below
+ // (and other reference-holding operations complete).
+ _, file := t.FDTable().Remove(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := file.OnClose(t)
+ return 0, nil, slinux.HandleIOErrorVFS2(t, false /* partial */, err, syserror.EINTR, "close", file)
+}
+
+// Dup implements Linux syscall dup(2).
+func Dup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newFD, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{})
+ if err != nil {
+ return 0, nil, syserror.EMFILE
+ }
+ return uintptr(newFD), nil, nil
+}
+
+// Dup2 implements Linux syscall dup2(2).
+func Dup2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+
+ if oldfd == newfd {
+ // As long as oldfd is valid, dup2() does nothing and returns newfd.
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ file.DecRef()
+ return uintptr(newfd), nil, nil
+ }
+
+ return dup3(t, oldfd, newfd, 0)
+}
+
+// Dup3 implements Linux syscall dup3(2).
+func Dup3(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldfd := args[0].Int()
+ newfd := args[1].Int()
+ flags := args[2].Uint()
+
+ if oldfd == newfd {
+ return 0, nil, syserror.EINVAL
+ }
+
+ return dup3(t, oldfd, newfd, flags)
+}
+
+func dup3(t *kernel.Task, oldfd, newfd int32, flags uint32) (uintptr, *kernel.SyscallControl, error) {
+ if flags&^linux.O_CLOEXEC != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(oldfd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ err := t.NewFDAtVFS2(newfd, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(newfd), nil, nil
+}
+
+// Fcntl implements linux syscall fcntl(2).
+func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ cmd := args[1].Int()
+
+ file, flags := t.FDTable().GetVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ switch cmd {
+ case linux.F_DUPFD, linux.F_DUPFD_CLOEXEC:
+ minfd := args[2].Int()
+ fd, err := t.NewFDFromVFS2(minfd, file, kernel.FDFlags{
+ CloseOnExec: cmd == linux.F_DUPFD_CLOEXEC,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(fd), nil, nil
+ case linux.F_GETFD:
+ return uintptr(flags.ToLinuxFDFlags()), nil, nil
+ case linux.F_SETFD:
+ flags := args[2].Uint()
+ t.FDTable().SetFlags(fd, kernel.FDFlags{
+ CloseOnExec: flags&linux.FD_CLOEXEC != 0,
+ })
+ return 0, nil, nil
+ case linux.F_GETFL:
+ return uintptr(file.StatusFlags()), nil, nil
+ case linux.F_SETFL:
+ return 0, nil, file.SetStatusFlags(t, t.Credentials(), args[2].Uint())
+ default:
+ // TODO(gvisor.dev/issue/1623): Everything else is not yet supported.
+ return 0, nil, syserror.EINVAL
+ }
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/filesystem.go b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
new file mode 100644
index 000000000..fc5ceea4c
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/filesystem.go
@@ -0,0 +1,326 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Link implements Linux syscall link(2).
+func Link(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, linkat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Linkat implements Linux syscall linkat(2).
+func Linkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Int()
+ return 0, nil, linkat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func linkat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_FOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+ if flags&linux.AT_EMPTY_PATH != 0 && !t.HasCapability(linux.CAP_DAC_READ_SEARCH) {
+ return syserror.ENOENT
+ }
+
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_FOLLOW != 0))
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().LinkAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop)
+}
+
+// Mkdir implements Linux syscall mkdir(2).
+func Mkdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, mkdirat(t, linux.AT_FDCWD, addr, mode)
+}
+
+// Mkdirat implements Linux syscall mkdirat(2).
+func Mkdirat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, mkdirat(t, dirfd, addr, mode)
+}
+
+func mkdirat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().MkdirAt(t, t.Credentials(), &tpop.pop, &vfs.MkdirOptions{
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+}
+
+// Mknod implements Linux syscall mknod(2).
+func Mknod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ dev := args[2].Uint()
+ return 0, nil, mknodat(t, linux.AT_FDCWD, addr, mode, dev)
+}
+
+// Mknodat implements Linux syscall mknodat(2).
+func Mknodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ mode := args[2].ModeT()
+ dev := args[3].Uint()
+ return 0, nil, mknodat(t, dirfd, addr, mode, dev)
+}
+
+func mknodat(t *kernel.Task, dirfd int32, addr usermem.Addr, mode uint, dev uint32) error {
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ major, minor := linux.DecodeDeviceID(dev)
+ return t.Kernel().VFS().MknodAt(t, t.Credentials(), &tpop.pop, &vfs.MknodOptions{
+ Mode: linux.FileMode(mode &^ t.FSContext().Umask()),
+ DevMajor: uint32(major),
+ DevMinor: minor,
+ })
+}
+
+// Open implements Linux syscall open(2).
+func Open(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ flags := args[1].Uint()
+ mode := args[2].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, flags, mode)
+}
+
+// Openat implements Linux syscall openat(2).
+func Openat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ addr := args[1].Pointer()
+ flags := args[2].Uint()
+ mode := args[3].ModeT()
+ return openat(t, dirfd, addr, flags, mode)
+}
+
+// Creat implements Linux syscall creat(2).
+func Creat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return openat(t, linux.AT_FDCWD, addr, linux.O_WRONLY|linux.O_CREAT|linux.O_TRUNC, mode)
+}
+
+func openat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, flags uint32, mode uint) (uintptr, *kernel.SyscallControl, error) {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, shouldFollowFinalSymlink(flags&linux.O_NOFOLLOW == 0))
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ file, err := t.Kernel().VFS().OpenAt(t, t.Credentials(), &tpop.pop, &vfs.OpenOptions{
+ Flags: flags,
+ Mode: linux.FileMode(mode & (0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX) &^ t.FSContext().Umask()),
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ defer file.DecRef()
+
+ fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{
+ CloseOnExec: flags&linux.O_CLOEXEC != 0,
+ })
+ return uintptr(fd), nil, err
+}
+
+// Rename implements Linux syscall rename(2).
+func Rename(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ oldpathAddr := args[0].Pointer()
+ newpathAddr := args[1].Pointer()
+ return 0, nil, renameat(t, linux.AT_FDCWD, oldpathAddr, linux.AT_FDCWD, newpathAddr, 0 /* flags */)
+}
+
+// Renameat implements Linux syscall renameat(2).
+func Renameat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, 0 /* flags */)
+}
+
+// Renameat2 implements Linux syscall renameat2(2).
+func Renameat2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ olddirfd := args[0].Int()
+ oldpathAddr := args[1].Pointer()
+ newdirfd := args[2].Int()
+ newpathAddr := args[3].Pointer()
+ flags := args[4].Uint()
+ return 0, nil, renameat(t, olddirfd, oldpathAddr, newdirfd, newpathAddr, flags)
+}
+
+func renameat(t *kernel.Task, olddirfd int32, oldpathAddr usermem.Addr, newdirfd int32, newpathAddr usermem.Addr, flags uint32) error {
+ oldpath, err := copyInPath(t, oldpathAddr)
+ if err != nil {
+ return err
+ }
+ // "If oldpath refers to a symbolic link, the link is renamed" - rename(2)
+ oldtpop, err := getTaskPathOperation(t, olddirfd, oldpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer oldtpop.Release()
+
+ newpath, err := copyInPath(t, newpathAddr)
+ if err != nil {
+ return err
+ }
+ newtpop, err := getTaskPathOperation(t, newdirfd, newpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer newtpop.Release()
+
+ return t.Kernel().VFS().RenameAt(t, t.Credentials(), &oldtpop.pop, &newtpop.pop, &vfs.RenameOptions{
+ Flags: flags,
+ })
+}
+
+// Rmdir implements Linux syscall rmdir(2).
+func Rmdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, rmdirat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func rmdirat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().RmdirAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlink implements Linux syscall unlink(2).
+func Unlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ return 0, nil, unlinkat(t, linux.AT_FDCWD, pathAddr)
+}
+
+func unlinkat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, dirfd, path, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &tpop.pop)
+}
+
+// Unlinkat implements Linux syscall unlinkat(2).
+func Unlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+
+ if flags&^linux.AT_REMOVEDIR != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ if flags&linux.AT_REMOVEDIR != 0 {
+ return 0, nil, rmdirat(t, dirfd, pathAddr)
+ }
+ return 0, nil, unlinkat(t, dirfd, pathAddr)
+}
+
+// Symlink implements Linux syscall symlink(2).
+func Symlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ linkpathAddr := args[1].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, linux.AT_FDCWD, linkpathAddr)
+}
+
+// Symlinkat implements Linux syscall symlinkat(2).
+func Symlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ targetAddr := args[0].Pointer()
+ newdirfd := args[1].Int()
+ linkpathAddr := args[2].Pointer()
+ return 0, nil, symlinkat(t, targetAddr, newdirfd, linkpathAddr)
+}
+
+func symlinkat(t *kernel.Task, targetAddr usermem.Addr, newdirfd int32, linkpathAddr usermem.Addr) error {
+ target, err := t.CopyInString(targetAddr, linux.PATH_MAX)
+ if err != nil {
+ return err
+ }
+ linkpath, err := copyInPath(t, linkpathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, newdirfd, linkpath, disallowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+ return t.Kernel().VFS().SymlinkAt(t, t.Credentials(), &tpop.pop, target)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/fscontext.go b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
new file mode 100644
index 000000000..317409a18
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/fscontext.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Getcwd implements Linux syscall getcwd(2).
+func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ size := args[1].SizeT()
+
+ root := t.FSContext().RootDirectoryVFS2()
+ wd := t.FSContext().WorkingDirectoryVFS2()
+ s, err := t.Kernel().VFS().PathnameForGetcwd(t, root, wd)
+ root.DecRef()
+ wd.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Note this is >= because we need a terminator.
+ if uint(len(s)) >= size {
+ return 0, nil, syserror.ERANGE
+ }
+
+ // Construct a byte slice containing a NUL terminator.
+ buf := t.CopyScratchBuffer(len(s) + 1)
+ copy(buf, s)
+ buf[len(buf)-1] = 0
+
+ // Write the pathname slice.
+ n, err := t.CopyOutBytes(addr, buf)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Chdir implements Linux syscall chdir(2).
+func Chdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Fchdir implements Linux syscall fchdir(2).
+func Fchdir(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetWorkingDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
+
+// Chroot implements Linux syscall chroot(2).
+func Chroot(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+
+ if !t.HasCapability(linux.CAP_SYS_CHROOT) {
+ return 0, nil, syserror.EPERM
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ vd, err := t.Kernel().VFS().GetDentryAt(t, t.Credentials(), &tpop.pop, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ t.FSContext().SetRootDirectoryVFS2(vd)
+ vd.DecRef()
+ return 0, nil, nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/getdents.go b/pkg/sentry/syscalls/linux/vfs2/getdents.go
new file mode 100644
index 000000000..ddc140b65
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/getdents.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Getdents implements Linux syscall getdents(2).
+func Getdents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, false /* isGetdents64 */)
+}
+
+// Getdents64 implements Linux syscall getdents64(2).
+func Getdents64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getdents(t, args, true /* isGetdents64 */)
+}
+
+func getdents(t *kernel.Task, args arch.SyscallArguments, isGetdents64 bool) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := int(args[2].Uint())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ cb := getGetdentsCallback(t, addr, size, isGetdents64)
+ err := file.IterDirents(t, cb)
+ n := size - cb.remaining
+ putGetdentsCallback(cb)
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+type getdentsCallback struct {
+ t *kernel.Task
+ addr usermem.Addr
+ remaining int
+ isGetdents64 bool
+}
+
+var getdentsCallbackPool = sync.Pool{
+ New: func() interface{} {
+ return &getdentsCallback{}
+ },
+}
+
+func getGetdentsCallback(t *kernel.Task, addr usermem.Addr, size int, isGetdents64 bool) *getdentsCallback {
+ cb := getdentsCallbackPool.Get().(*getdentsCallback)
+ *cb = getdentsCallback{
+ t: t,
+ addr: addr,
+ remaining: size,
+ isGetdents64: isGetdents64,
+ }
+ return cb
+}
+
+func putGetdentsCallback(cb *getdentsCallback) {
+ cb.t = nil
+ getdentsCallbackPool.Put(cb)
+}
+
+// Handle implements vfs.IterDirentsCallback.Handle.
+func (cb *getdentsCallback) Handle(dirent vfs.Dirent) error {
+ var buf []byte
+ if cb.isGetdents64 {
+ // struct linux_dirent64 {
+ // ino64_t d_ino; /* 64-bit inode number */
+ // off64_t d_off; /* 64-bit offset to next structure */
+ // unsigned short d_reclen; /* Size of this dirent */
+ // unsigned char d_type; /* File type */
+ // char d_name[]; /* Filename (null-terminated) */
+ // };
+ size := 8 + 8 + 2 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ buf[18] = dirent.Type
+ copy(buf[19:], dirent.Name)
+ buf[size-1] = 0 // NUL terminator
+ } else {
+ // struct linux_dirent {
+ // unsigned long d_ino; /* Inode number */
+ // unsigned long d_off; /* Offset to next linux_dirent */
+ // unsigned short d_reclen; /* Length of this linux_dirent */
+ // char d_name[]; /* Filename (null-terminated) */
+ // /* length is actually (d_reclen - 2 -
+ // offsetof(struct linux_dirent, d_name)) */
+ // /*
+ // char pad; // Zero padding byte
+ // char d_type; // File type (only since Linux
+ // // 2.6.4); offset is (d_reclen - 1)
+ // */
+ // };
+ if cb.t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(unsigned long): %d", cb.t.Arch().Width()))
+ }
+ size := 8 + 8 + 2 + 1 + 1 + 1 + len(dirent.Name)
+ if size < cb.remaining {
+ return syserror.EINVAL
+ }
+ buf = cb.t.CopyScratchBuffer(size)
+ usermem.ByteOrder.PutUint64(buf[0:8], dirent.Ino)
+ usermem.ByteOrder.PutUint64(buf[8:16], uint64(dirent.NextOff))
+ usermem.ByteOrder.PutUint16(buf[16:18], uint16(size))
+ copy(buf[18:], dirent.Name)
+ buf[size-3] = 0 // NUL terminator
+ buf[size-2] = 0 // zero padding byte
+ buf[size-1] = dirent.Type
+ }
+ n, err := cb.t.CopyOutBytes(cb.addr, buf)
+ if err != nil {
+ // Don't report partially-written dirents by advancing cb.addr or
+ // cb.remaining.
+ return err
+ }
+ cb.addr += usermem.Addr(n)
+ cb.remaining -= n
+ return nil
+}
diff --git a/pkg/usermem/usermem_unsafe.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 876783e78..5a2418da9 100644
--- a/pkg/usermem/usermem_unsafe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -1,4 +1,4 @@
-// Copyright 2019 The gVisor Authors.
+// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,16 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package usermem
+package vfs2
import (
- "unsafe"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
)
-// stringFromImmutableBytes is equivalent to string(bs), except that it never
-// copies even if escape analysis can't prove that bs does not escape. This is
-// only valid if bs is never mutated after stringFromImmutableBytes returns.
-func stringFromImmutableBytes(bs []byte) string {
- // Compare strings.Builder.String().
- return *(*string)(unsafe.Pointer(&bs))
+// Ioctl implements Linux syscall ioctl(2).
+func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ ret, err := file.Ioctl(t, t.MemoryManager(), args)
+ return ret, nil, err
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
index e0ac32b33..7d220bc20 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_amd64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build amd64
+
package vfs2
import (
@@ -22,110 +24,142 @@ import (
// Override syscall table to add syscalls implementations from this package.
func Override(table map[uintptr]kernel.Syscall) {
table[0] = syscalls.Supported("read", Read)
-
- // Remove syscalls that haven't been converted yet. It's better to get ENOSYS
- // rather than a SIGSEGV deep in the stack.
- delete(table, 1) // write
- delete(table, 2) // open
- delete(table, 3) // close
- delete(table, 4) // stat
- delete(table, 5) // fstat
- delete(table, 6) // lstat
- delete(table, 7) // poll
- delete(table, 8) // lseek
- delete(table, 9) // mmap
- delete(table, 16) // ioctl
- delete(table, 17) // pread64
- delete(table, 18) // pwrite64
- delete(table, 19) // readv
- delete(table, 20) // writev
- delete(table, 21) // access
- delete(table, 22) // pipe
- delete(table, 32) // dup
- delete(table, 33) // dup2
- delete(table, 40) // sendfile
- delete(table, 59) // execve
- delete(table, 72) // fcntl
- delete(table, 73) // flock
- delete(table, 74) // fsync
- delete(table, 75) // fdatasync
- delete(table, 76) // truncate
- delete(table, 77) // ftruncate
- delete(table, 78) // getdents
- delete(table, 79) // getcwd
- delete(table, 80) // chdir
- delete(table, 81) // fchdir
- delete(table, 82) // rename
- delete(table, 83) // mkdir
- delete(table, 84) // rmdir
- delete(table, 85) // creat
- delete(table, 86) // link
- delete(table, 87) // unlink
- delete(table, 88) // symlink
- delete(table, 89) // readlink
- delete(table, 90) // chmod
- delete(table, 91) // fchmod
- delete(table, 92) // chown
- delete(table, 93) // fchown
- delete(table, 94) // lchown
- delete(table, 133) // mknod
- delete(table, 137) // statfs
- delete(table, 138) // fstatfs
- delete(table, 161) // chroot
- delete(table, 162) // sync
+ table[1] = syscalls.Supported("write", Write)
+ table[2] = syscalls.Supported("open", Open)
+ table[3] = syscalls.Supported("close", Close)
+ table[4] = syscalls.Supported("stat", Stat)
+ table[5] = syscalls.Supported("fstat", Fstat)
+ table[6] = syscalls.Supported("lstat", Lstat)
+ table[7] = syscalls.Supported("poll", Poll)
+ table[8] = syscalls.Supported("lseek", Lseek)
+ table[9] = syscalls.Supported("mmap", Mmap)
+ table[16] = syscalls.Supported("ioctl", Ioctl)
+ table[17] = syscalls.Supported("pread64", Pread64)
+ table[18] = syscalls.Supported("pwrite64", Pwrite64)
+ table[19] = syscalls.Supported("readv", Readv)
+ table[20] = syscalls.Supported("writev", Writev)
+ table[21] = syscalls.Supported("access", Access)
+ delete(table, 22) // pipe
+ table[23] = syscalls.Supported("select", Select)
+ table[32] = syscalls.Supported("dup", Dup)
+ table[33] = syscalls.Supported("dup2", Dup2)
+ delete(table, 40) // sendfile
+ delete(table, 41) // socket
+ delete(table, 42) // connect
+ delete(table, 43) // accept
+ delete(table, 44) // sendto
+ delete(table, 45) // recvfrom
+ delete(table, 46) // sendmsg
+ delete(table, 47) // recvmsg
+ delete(table, 48) // shutdown
+ delete(table, 49) // bind
+ delete(table, 50) // listen
+ delete(table, 51) // getsockname
+ delete(table, 52) // getpeername
+ delete(table, 53) // socketpair
+ delete(table, 54) // setsockopt
+ delete(table, 55) // getsockopt
+ table[59] = syscalls.Supported("execve", Execve)
+ table[72] = syscalls.Supported("fcntl", Fcntl)
+ delete(table, 73) // flock
+ table[74] = syscalls.Supported("fsync", Fsync)
+ table[75] = syscalls.Supported("fdatasync", Fdatasync)
+ table[76] = syscalls.Supported("truncate", Truncate)
+ table[77] = syscalls.Supported("ftruncate", Ftruncate)
+ table[78] = syscalls.Supported("getdents", Getdents)
+ table[79] = syscalls.Supported("getcwd", Getcwd)
+ table[80] = syscalls.Supported("chdir", Chdir)
+ table[81] = syscalls.Supported("fchdir", Fchdir)
+ table[82] = syscalls.Supported("rename", Rename)
+ table[83] = syscalls.Supported("mkdir", Mkdir)
+ table[84] = syscalls.Supported("rmdir", Rmdir)
+ table[85] = syscalls.Supported("creat", Creat)
+ table[86] = syscalls.Supported("link", Link)
+ table[87] = syscalls.Supported("unlink", Unlink)
+ table[88] = syscalls.Supported("symlink", Symlink)
+ table[89] = syscalls.Supported("readlink", Readlink)
+ table[90] = syscalls.Supported("chmod", Chmod)
+ table[91] = syscalls.Supported("fchmod", Fchmod)
+ table[92] = syscalls.Supported("chown", Chown)
+ table[93] = syscalls.Supported("fchown", Fchown)
+ table[94] = syscalls.Supported("lchown", Lchown)
+ table[132] = syscalls.Supported("utime", Utime)
+ table[133] = syscalls.Supported("mknod", Mknod)
+ table[137] = syscalls.Supported("statfs", Statfs)
+ table[138] = syscalls.Supported("fstatfs", Fstatfs)
+ table[161] = syscalls.Supported("chroot", Chroot)
+ table[162] = syscalls.Supported("sync", Sync)
delete(table, 165) // mount
delete(table, 166) // umount2
- delete(table, 172) // iopl
- delete(table, 173) // ioperm
delete(table, 187) // readahead
- delete(table, 188) // setxattr
- delete(table, 189) // lsetxattr
- delete(table, 190) // fsetxattr
- delete(table, 191) // getxattr
- delete(table, 192) // lgetxattr
- delete(table, 193) // fgetxattr
+ table[188] = syscalls.Supported("setxattr", Setxattr)
+ table[189] = syscalls.Supported("lsetxattr", Lsetxattr)
+ table[190] = syscalls.Supported("fsetxattr", Fsetxattr)
+ table[191] = syscalls.Supported("getxattr", Getxattr)
+ table[192] = syscalls.Supported("lgetxattr", Lgetxattr)
+ table[193] = syscalls.Supported("fgetxattr", Fgetxattr)
+ table[194] = syscalls.Supported("listxattr", Listxattr)
+ table[195] = syscalls.Supported("llistxattr", Llistxattr)
+ table[196] = syscalls.Supported("flistxattr", Flistxattr)
+ table[197] = syscalls.Supported("removexattr", Removexattr)
+ table[198] = syscalls.Supported("lremovexattr", Lremovexattr)
+ table[199] = syscalls.Supported("fremovexattr", Fremovexattr)
delete(table, 206) // io_setup
delete(table, 207) // io_destroy
delete(table, 208) // io_getevents
delete(table, 209) // io_submit
delete(table, 210) // io_cancel
- delete(table, 213) // epoll_create
- delete(table, 214) // epoll_ctl_old
- delete(table, 215) // epoll_wait_old
- delete(table, 216) // remap_file_pages
- delete(table, 217) // getdents64
- delete(table, 232) // epoll_wait
- delete(table, 233) // epoll_ctl
+ table[213] = syscalls.Supported("epoll_create", EpollCreate)
+ table[217] = syscalls.Supported("getdents64", Getdents64)
+ delete(table, 221) // fdavise64
+ table[232] = syscalls.Supported("epoll_wait", EpollWait)
+ table[233] = syscalls.Supported("epoll_ctl", EpollCtl)
+ table[235] = syscalls.Supported("utimes", Utimes)
delete(table, 253) // inotify_init
delete(table, 254) // inotify_add_watch
delete(table, 255) // inotify_rm_watch
- delete(table, 257) // openat
- delete(table, 258) // mkdirat
- delete(table, 259) // mknodat
- delete(table, 260) // fchownat
- delete(table, 261) // futimesat
- delete(table, 262) // fstatat
- delete(table, 263) // unlinkat
- delete(table, 264) // renameat
- delete(table, 265) // linkat
- delete(table, 266) // symlinkat
- delete(table, 267) // readlinkat
- delete(table, 268) // fchmodat
- delete(table, 269) // faccessat
- delete(table, 270) // pselect
- delete(table, 271) // ppoll
+ table[257] = syscalls.Supported("openat", Openat)
+ table[258] = syscalls.Supported("mkdirat", Mkdirat)
+ table[259] = syscalls.Supported("mknodat", Mknodat)
+ table[260] = syscalls.Supported("fchownat", Fchownat)
+ table[261] = syscalls.Supported("futimens", Futimens)
+ table[262] = syscalls.Supported("newfstatat", Newfstatat)
+ table[263] = syscalls.Supported("unlinkat", Unlinkat)
+ table[264] = syscalls.Supported("renameat", Renameat)
+ table[265] = syscalls.Supported("linkat", Linkat)
+ table[266] = syscalls.Supported("symlinkat", Symlinkat)
+ table[267] = syscalls.Supported("readlinkat", Readlinkat)
+ table[268] = syscalls.Supported("fchmodat", Fchmodat)
+ table[269] = syscalls.Supported("faccessat", Faccessat)
+ table[270] = syscalls.Supported("pselect", Pselect)
+ table[271] = syscalls.Supported("ppoll", Ppoll)
+ delete(table, 275) // splice
+ delete(table, 276) // tee
+ table[277] = syscalls.Supported("sync_file_range", SyncFileRange)
+ table[280] = syscalls.Supported("utimensat", Utimensat)
+ table[281] = syscalls.Supported("epoll_pwait", EpollPwait)
+ delete(table, 282) // signalfd
+ delete(table, 283) // timerfd_create
+ delete(table, 284) // eventfd
delete(table, 285) // fallocate
- delete(table, 291) // epoll_create1
- delete(table, 292) // dup3
+ delete(table, 286) // timerfd_settime
+ delete(table, 287) // timerfd_gettime
+ delete(table, 288) // accept4
+ delete(table, 289) // signalfd4
+ delete(table, 290) // eventfd2
+ table[291] = syscalls.Supported("epoll_create1", EpollCreate1)
+ table[292] = syscalls.Supported("dup3", Dup3)
delete(table, 293) // pipe2
delete(table, 294) // inotify_init1
- delete(table, 295) // preadv
- delete(table, 296) // pwritev
- delete(table, 306) // syncfs
- delete(table, 316) // renameat2
+ table[295] = syscalls.Supported("preadv", Preadv)
+ table[296] = syscalls.Supported("pwritev", Pwritev)
+ delete(table, 299) // recvmmsg
+ table[306] = syscalls.Supported("syncfs", Syncfs)
+ delete(table, 307) // sendmmsg
+ table[316] = syscalls.Supported("renameat2", Renameat2)
delete(table, 319) // memfd_create
- delete(table, 322) // execveat
- delete(table, 327) // preadv2
- delete(table, 328) // pwritev2
- delete(table, 332) // statx
+ table[322] = syscalls.Supported("execveat", Execveat)
+ table[327] = syscalls.Supported("preadv2", Preadv2)
+ table[328] = syscalls.Supported("pwritev2", Pwritev2)
+ table[332] = syscalls.Supported("statx", Statx)
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
index 6af5c400f..a6b367468 100644
--- a/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
+++ b/pkg/sentry/syscalls/linux/vfs2/linux64_override_arm64.go
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// +build arm64
+
package vfs2
import (
diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go
new file mode 100644
index 000000000..60a43f0a0
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go
@@ -0,0 +1,92 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Mmap implements Linux syscall mmap(2).
+func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ prot := args[2].Int()
+ flags := args[3].Int()
+ fd := args[4].Int()
+ fixed := flags&linux.MAP_FIXED != 0
+ private := flags&linux.MAP_PRIVATE != 0
+ shared := flags&linux.MAP_SHARED != 0
+ anon := flags&linux.MAP_ANONYMOUS != 0
+ map32bit := flags&linux.MAP_32BIT != 0
+
+ // Require exactly one of MAP_PRIVATE and MAP_SHARED.
+ if private == shared {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := memmap.MMapOpts{
+ Length: args[1].Uint64(),
+ Offset: args[5].Uint64(),
+ Addr: args[0].Pointer(),
+ Fixed: fixed,
+ Unmap: fixed,
+ Map32Bit: map32bit,
+ Private: private,
+ Perms: usermem.AccessType{
+ Read: linux.PROT_READ&prot != 0,
+ Write: linux.PROT_WRITE&prot != 0,
+ Execute: linux.PROT_EXEC&prot != 0,
+ },
+ MaxPerms: usermem.AnyAccess,
+ GrowsDown: linux.MAP_GROWSDOWN&flags != 0,
+ Precommit: linux.MAP_POPULATE&flags != 0,
+ }
+ if linux.MAP_LOCKED&flags != 0 {
+ opts.MLockMode = memmap.MLockEager
+ }
+ defer func() {
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef()
+ }
+ }()
+
+ if !anon {
+ // Convert the passed FD to a file reference.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // mmap unconditionally requires that the FD is readable.
+ if !file.IsReadable() {
+ return 0, nil, syserror.EACCES
+ }
+ // MAP_SHARED requires that the FD be writable for PROT_WRITE.
+ if shared && !file.IsWritable() {
+ opts.MaxPerms.Write = false
+ }
+
+ if err := file.ConfigureMMap(t, &opts); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ rv, err := t.MemoryManager().MMap(t, opts)
+ return uintptr(rv), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/path.go b/pkg/sentry/syscalls/linux/vfs2/path.go
new file mode 100644
index 000000000..97da6c647
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/path.go
@@ -0,0 +1,94 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func copyInPath(t *kernel.Task, addr usermem.Addr) (fspath.Path, error) {
+ pathname, err := t.CopyInString(addr, linux.PATH_MAX)
+ if err != nil {
+ return fspath.Path{}, err
+ }
+ return fspath.Parse(pathname), nil
+}
+
+type taskPathOperation struct {
+ pop vfs.PathOperation
+ haveStartRef bool
+}
+
+func getTaskPathOperation(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink) (taskPathOperation, error) {
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ haveStartRef := false
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ root.DecRef()
+ return taskPathOperation{}, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ haveStartRef = true
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ root.DecRef()
+ return taskPathOperation{}, syserror.EBADF
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ haveStartRef = true
+ dirfile.DecRef()
+ }
+ }
+ return taskPathOperation{
+ pop: vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ },
+ haveStartRef: haveStartRef,
+ }, nil
+}
+
+func (tpop *taskPathOperation) Release() {
+ tpop.pop.Root.DecRef()
+ if tpop.haveStartRef {
+ tpop.pop.Start.DecRef()
+ tpop.haveStartRef = false
+ }
+}
+
+type shouldAllowEmptyPath bool
+
+const (
+ disallowEmptyPath shouldAllowEmptyPath = false
+ allowEmptyPath shouldAllowEmptyPath = true
+)
+
+type shouldFollowFinalSymlink bool
+
+const (
+ nofollowFinalSymlink shouldFollowFinalSymlink = false
+ followFinalSymlink shouldFollowFinalSymlink = true
+)
diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go
new file mode 100644
index 000000000..dbf4882da
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/poll.go
@@ -0,0 +1,584 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "fmt"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/limits"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// fileCap is the maximum allowable files for poll & select. This has no
+// equivalent in Linux; it exists in gVisor since allocation failure in Go is
+// unrecoverable.
+const fileCap = 1024 * 1024
+
+// Masks for "readable", "writable", and "exceptional" events as defined by
+// select(2).
+const (
+ // selectReadEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLIN_SET.
+ selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR
+
+ // selectWriteEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLOUT_SET.
+ selectWriteEvents = linux.POLLOUT | linux.POLLERR
+
+ // selectExceptEvents is analogous to the Linux kernel's
+ // fs/select.c:POLLEX_SET.
+ selectExceptEvents = linux.POLLPRI
+)
+
+// pollState tracks the associated file description and waiter of a PollFD.
+type pollState struct {
+ file *vfs.FileDescription
+ waiter waiter.Entry
+}
+
+// initReadiness gets the current ready mask for the file represented by the FD
+// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
+// used to register with the file for event notifications, and a reference to
+// the file is stored in "state".
+func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
+ if pfd.FD < 0 {
+ pfd.REvents = 0
+ return
+ }
+
+ file := t.GetFileVFS2(pfd.FD)
+ if file == nil {
+ pfd.REvents = linux.POLLNVAL
+ return
+ }
+
+ if ch == nil {
+ defer file.DecRef()
+ } else {
+ state.file = file
+ state.waiter, _ = waiter.NewChannelEntry(ch)
+ file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ }
+
+ r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
+ pfd.REvents = int16(r.ToLinux()) & pfd.Events
+}
+
+// releaseState releases all the pollState in "state".
+func releaseState(state []pollState) {
+ for i := range state {
+ if state[i].file != nil {
+ state[i].file.EventUnregister(&state[i].waiter)
+ state[i].file.DecRef()
+ }
+ }
+}
+
+// pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
+// when "timeout" is greater than zero.
+//
+// pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
+// positive if interrupted by a signal.
+func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
+ var ch chan struct{}
+ if timeout != 0 {
+ ch = make(chan struct{}, 1)
+ }
+
+ // Register for event notification in the files involved if we may
+ // block (timeout not zero). Once we find a file that has a non-zero
+ // result, we stop registering for events but still go through all files
+ // to get their ready masks.
+ state := make([]pollState, len(pfd))
+ defer releaseState(state)
+ n := uintptr(0)
+ for i := range pfd {
+ initReadiness(t, &pfd[i], &state[i], ch)
+ if pfd[i].REvents != 0 {
+ n++
+ ch = nil
+ }
+ }
+
+ if timeout == 0 {
+ return timeout, n, nil
+ }
+
+ haveTimeout := timeout >= 0
+
+ for n == 0 {
+ var err error
+ // Wait for a notification.
+ timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
+ if err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = nil
+ }
+ return timeout, 0, err
+ }
+
+ // We got notified, count how many files are ready. If none,
+ // then this was a spurious notification, and we just go back
+ // to sleep with the remaining timeout.
+ for i := range state {
+ if state[i].file == nil {
+ continue
+ }
+
+ r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
+ rl := int16(r.ToLinux()) & pfd[i].Events
+ if rl != 0 {
+ pfd[i].REvents = rl
+ n++
+ }
+ }
+ }
+
+ return timeout, n, nil
+}
+
+// copyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
+func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD, error) {
+ if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
+ return nil, syserror.EINVAL
+ }
+
+ pfd := make([]linux.PollFD, nfds)
+ if nfds > 0 {
+ if _, err := t.CopyIn(addr, &pfd); err != nil {
+ return nil, err
+ }
+ }
+
+ return pfd, nil
+}
+
+func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
+ pfd, err := copyInPollFDs(t, addr, nfds)
+ if err != nil {
+ return timeout, 0, err
+ }
+
+ // Compatibility warning: Linux adds POLLHUP and POLLERR just before
+ // polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
+ // polling, changing event masks here is an application-visible difference.
+ // (Linux also doesn't copy out event masks at all, only revents.)
+ for i := range pfd {
+ pfd[i].Events |= linux.POLLHUP | linux.POLLERR
+ }
+ remainingTimeout, n, err := pollBlock(t, pfd, timeout)
+ err = syserror.ConvertIntr(err, syserror.EINTR)
+
+ // The poll entries are copied out regardless of whether
+ // any are set or not. This aligns with the Linux behavior.
+ if nfds > 0 && err == nil {
+ if _, err := t.CopyOut(addr, pfd); err != nil {
+ return remainingTimeout, 0, err
+ }
+ }
+
+ return remainingTimeout, n, err
+}
+
+// CopyInFDSet copies an fd set from select(2)/pselect(2).
+func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
+ set := make([]byte, nBytes)
+
+ if addr != 0 {
+ if _, err := t.CopyIn(addr, &set); err != nil {
+ return nil, err
+ }
+ // If we only use part of the last byte, mask out the extraneous bits.
+ //
+ // N.B. This only works on little-endian architectures.
+ if nBitsInLastPartialByte != 0 {
+ set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
+ }
+ }
+ return set, nil
+}
+
+func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
+ if nfds < 0 || nfds > fileCap {
+ return 0, syserror.EINVAL
+ }
+
+ // Calculate the size of the fd sets (one bit per fd).
+ nBytes := (nfds + 7) / 8
+ nBitsInLastPartialByte := nfds % 8
+
+ // Capture all the provided input vectors.
+ r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+ e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
+ if err != nil {
+ return 0, err
+ }
+
+ // Count how many FDs are actually being requested so that we can build
+ // a PollFD array.
+ fdCount := 0
+ for i := 0; i < nBytes; i++ {
+ v := r[i] | w[i] | e[i]
+ for v != 0 {
+ v &= (v - 1)
+ fdCount++
+ }
+ }
+
+ // Build the PollFD array.
+ pfd := make([]linux.PollFD, 0, fdCount)
+ var fd int32
+ for i := 0; i < nBytes; i++ {
+ rV, wV, eV := r[i], w[i], e[i]
+ v := rV | wV | eV
+ m := byte(1)
+ for j := 0; j < 8; j++ {
+ if (v & m) != 0 {
+ // Make sure the fd is valid and decrement the reference
+ // immediately to ensure we don't leak. Note, another thread
+ // might be about to close fd. This is racy, but that's
+ // OK. Linux is racy in the same way.
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, syserror.EBADF
+ }
+ file.DecRef()
+
+ var mask int16
+ if (rV & m) != 0 {
+ mask |= selectReadEvents
+ }
+
+ if (wV & m) != 0 {
+ mask |= selectWriteEvents
+ }
+
+ if (eV & m) != 0 {
+ mask |= selectExceptEvents
+ }
+
+ pfd = append(pfd, linux.PollFD{
+ FD: fd,
+ Events: mask,
+ })
+ }
+
+ fd++
+ m <<= 1
+ }
+ }
+
+ // Do the syscall, then count the number of bits set.
+ if _, _, err = pollBlock(t, pfd, timeout); err != nil {
+ return 0, syserror.ConvertIntr(err, syserror.EINTR)
+ }
+
+ // r, w, and e are currently event mask bitsets; unset bits corresponding
+ // to events that *didn't* occur.
+ bitSetCount := uintptr(0)
+ for idx := range pfd {
+ events := pfd[idx].REvents
+ i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
+ m := byte(1) << j
+ if r[i]&m != 0 {
+ if (events & selectReadEvents) != 0 {
+ bitSetCount++
+ } else {
+ r[i] &^= m
+ }
+ }
+ if w[i]&m != 0 {
+ if (events & selectWriteEvents) != 0 {
+ bitSetCount++
+ } else {
+ w[i] &^= m
+ }
+ }
+ if e[i]&m != 0 {
+ if (events & selectExceptEvents) != 0 {
+ bitSetCount++
+ } else {
+ e[i] &^= m
+ }
+ }
+ }
+
+ // Copy updated vectors back.
+ if readFDs != 0 {
+ if _, err := t.CopyOut(readFDs, r); err != nil {
+ return 0, err
+ }
+ }
+
+ if writeFDs != 0 {
+ if _, err := t.CopyOut(writeFDs, w); err != nil {
+ return 0, err
+ }
+ }
+
+ if exceptFDs != 0 {
+ if _, err := t.CopyOut(exceptFDs, e); err != nil {
+ return 0, err
+ }
+ }
+
+ return bitSetCount, nil
+}
+
+// timeoutRemaining returns the amount of time remaining for the specified
+// timeout or 0 if it has elapsed.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
+ now := t.Kernel().MonotonicClock().Now()
+ remaining := timeout - now.Sub(startNs)
+ if remaining < 0 {
+ remaining = 0
+ }
+ return remaining
+}
+
+// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
+ return tsRemaining.CopyOut(t, timespecAddr)
+}
+
+// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
+//
+// startNs must be from CLOCK_MONOTONIC.
+func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
+ if timeout <= 0 {
+ return nil
+ }
+ remaining := timeoutRemaining(t, startNs, timeout)
+ tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
+ return tvRemaining.CopyOut(t, timevalAddr)
+}
+
+// pollRestartBlock encapsulates the state required to restart poll(2) via
+// restart_syscall(2).
+//
+// +stateify savable
+type pollRestartBlock struct {
+ pfdAddr usermem.Addr
+ nfds uint
+ timeout time.Duration
+}
+
+// Restart implements kernel.SyscallRestartBlock.Restart.
+func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
+ return poll(t, p.pfdAddr, p.nfds, p.timeout)
+}
+
+func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
+ remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ // On an interrupt poll(2) is restarted with the remaining timeout.
+ if err == syserror.EINTR {
+ t.SetSyscallRestartBlock(&pollRestartBlock{
+ pfdAddr: pfdAddr,
+ nfds: nfds,
+ timeout: remainingTimeout,
+ })
+ return 0, kernel.ERESTART_RESTARTBLOCK
+ }
+ return n, err
+}
+
+// Poll implements linux syscall poll(2).
+func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timeout := time.Duration(args[2].Int()) * time.Millisecond
+ n, err := poll(t, pfdAddr, nfds, timeout)
+ return n, nil, err
+}
+
+// Ppoll implements linux syscall ppoll(2).
+func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pfdAddr := args[0].Pointer()
+ nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
+ timespecAddr := args[2].Pointer()
+ maskAddr := args[3].Pointer()
+ maskSize := uint(args[4].Uint())
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
+ return 0, nil, err
+ }
+
+ _, n, err := doPoll(t, pfdAddr, nfds, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // doPoll returns EINTR if interrupted, but ppoll is normally restartable
+ // if interrupted by something other than a signal handled by the
+ // application (i.e. returns ERESTARTNOHAND). However, if
+ // copyOutTimespecRemaining failed, then the restarted ppoll would use the
+ // wrong timeout, so the error should be left as EINTR.
+ //
+ // Note that this means that if err is nil but copyErr is not, copyErr is
+ // ignored. This is consistent with Linux.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Select implements linux syscall select(2).
+func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timevalAddr := args[4].Pointer()
+
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timevalAddr != 0 {
+ var timeval linux.Timeval
+ if err := timeval.CopyIn(t, timevalAddr); err != nil {
+ return 0, nil, err
+ }
+ if timeval.Sec < 0 || timeval.Usec < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ timeout = time.Duration(timeval.ToNsecCapped())
+ }
+ startNs := t.Kernel().MonotonicClock().Now()
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// Pselect implements linux syscall pselect(2).
+func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ nfds := int(args[0].Int()) // select(2) uses an int.
+ readFDs := args[1].Pointer()
+ writeFDs := args[2].Pointer()
+ exceptFDs := args[3].Pointer()
+ timespecAddr := args[4].Pointer()
+ maskWithSizeAddr := args[5].Pointer()
+
+ timeout, err := copyTimespecInToDuration(t, timespecAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var startNs ktime.Time
+ if timeout > 0 {
+ startNs = t.Kernel().MonotonicClock().Now()
+ }
+
+ if maskWithSizeAddr != 0 {
+ if t.Arch().Width() != 8 {
+ panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
+ }
+ var maskStruct sigSetWithSize
+ if err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
+ return 0, nil, err
+ }
+ if err := setTempSignalSet(t, usermem.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
+ return 0, nil, err
+ }
+ }
+
+ n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
+ copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
+ // See comment in Ppoll.
+ if err == syserror.EINTR && copyErr == nil {
+ err = kernel.ERESTARTNOHAND
+ }
+ return n, nil, err
+}
+
+// +marshal
+type sigSetWithSize struct {
+ sigsetAddr uint64
+ sizeofSigset uint64
+}
+
+// copyTimespecInToDuration copies a Timespec from the untrusted app range,
+// validates it and converts it to a Duration.
+//
+// If the Timespec is larger than what can be represented in a Duration, the
+// returned value is the maximum that Duration will allow.
+//
+// If timespecAddr is NULL, the returned value is negative.
+func copyTimespecInToDuration(t *kernel.Task, timespecAddr usermem.Addr) (time.Duration, error) {
+ // Use a negative Duration to indicate "no timeout".
+ timeout := time.Duration(-1)
+ if timespecAddr != 0 {
+ var timespec linux.Timespec
+ if err := timespec.CopyIn(t, timespecAddr); err != nil {
+ return 0, err
+ }
+ if !timespec.Valid() {
+ return 0, syserror.EINVAL
+ }
+ timeout = time.Duration(timespec.ToNsecCapped())
+ }
+ return timeout, nil
+}
+
+func setTempSignalSet(t *kernel.Task, maskAddr usermem.Addr, maskSize uint) error {
+ if maskAddr == 0 {
+ return nil
+ }
+ if maskSize != linux.SignalSetSize {
+ return syserror.EINVAL
+ }
+ var mask linux.SignalSet
+ if err := mask.CopyIn(t, maskAddr); err != nil {
+ return err
+ }
+ mask &^= kernel.UnblockableSignals
+ oldmask := t.SignalMask()
+ t.SetSignalMask(mask)
+ t.SetSavedSignalMask(oldmask)
+ return nil
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go
new file mode 100644
index 000000000..35f6308d6
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go
@@ -0,0 +1,511 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ eventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
+ eventMaskWrite = waiter.EventOut | waiter.EventHUp | waiter.EventErr
+)
+
+// Read implements Linux syscall read(2).
+func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
+}
+
+// Readv implements Linux syscall readv(2).
+func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := read(t, file, dst, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file)
+}
+
+func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.Read(t, dst, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Read(t, dst, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pread64 implements Linux syscall pread64(2).
+func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file)
+}
+
+// Preadv implements Linux syscall preadv(2).
+func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pread(t, file, dst, offset, vfs.ReadOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file)
+}
+
+// Preadv2 implements Linux syscall preadv2(2).
+func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // preadv2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1142)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the destination of the read.
+ dst, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.ReadOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = read(t, file, dst, opts)
+ } else {
+ n, err = pread(t, file, dst, offset, opts)
+ }
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file)
+}
+
+func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ n, err := file.PRead(t, dst, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskRead)
+
+ total := n
+ for {
+ // Shorten dst to reflect bytes previously read.
+ dst = dst.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PRead(t, dst, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Write implements Linux syscall write(2).
+func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file)
+}
+
+// Writev implements Linux syscall writev(2).
+func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := write(t, file, src, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file)
+}
+
+func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.Write(t, src, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.Write(t, src, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Pwrite64 implements Linux syscall pwrite64(2).
+func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ size := args[2].SizeT()
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the size is legitimate.
+ si := int(size)
+ if si < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file)
+}
+
+// Pwritev implements Linux syscall pwritev(2).
+func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ n, err := pwrite(t, file, src, offset, vfs.WriteOptions{})
+ t.IOUsage().AccountReadSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file)
+}
+
+// Pwritev2 implements Linux syscall pwritev2(2).
+func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // While the glibc signature is
+ // pwritev2(int fd, struct iovec* iov, int iov_cnt, off_t offset, int flags)
+ // the actual syscall
+ // (https://elixir.bootlin.com/linux/v5.5/source/fs/read_write.c#L1162)
+ // splits the offset argument into a high/low value for compatibility with
+ // 32-bit architectures. The flags argument is the 6th argument (index 5).
+ fd := args[0].Int()
+ addr := args[1].Pointer()
+ iovcnt := int(args[2].Int())
+ offset := args[3].Int64()
+ flags := args[5].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the offset is legitimate.
+ if offset < -1 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Get the source of the write.
+ src, err := t.IovecsIOSequence(addr, iovcnt, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.WriteOptions{
+ Flags: uint32(flags),
+ }
+ var n int64
+ if offset == -1 {
+ n, err = write(t, file, src, opts)
+ } else {
+ n, err = pwrite(t, file, src, offset, opts)
+ }
+ t.IOUsage().AccountWriteSyscall(n)
+ return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file)
+}
+
+func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ n, err := file.PWrite(t, src, offset, opts)
+ if err != syserror.ErrWouldBlock || file.StatusFlags()&linux.O_NONBLOCK != 0 {
+ return n, err
+ }
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ file.EventRegister(&w, eventMaskWrite)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst(int(n))
+
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err := file.PWrite(t, src, offset+total, opts)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+ if err := t.Block(ch); err != nil {
+ break
+ }
+ }
+ file.EventUnregister(&w)
+
+ return total, err
+}
+
+// Lseek implements Linux syscall lseek(2).
+func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ whence := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ newoff, err := file.Seek(t, offset, whence)
+ return uintptr(newoff), nil, err
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go
new file mode 100644
index 000000000..9250659ff
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go
@@ -0,0 +1,380 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const chmodMask = 0777 | linux.S_ISUID | linux.S_ISGID | linux.S_ISVTX
+
+// Chmod implements Linux syscall chmod(2).
+func Chmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ mode := args[1].ModeT()
+ return 0, nil, fchmodat(t, linux.AT_FDCWD, pathAddr, mode)
+}
+
+// Fchmodat implements Linux syscall fchmodat(2).
+func Fchmodat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ mode := args[2].ModeT()
+ return 0, nil, fchmodat(t, dirfd, pathAddr, mode)
+}
+
+func fchmodat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, mode uint) error {
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Fchmod implements Linux syscall fchmod(2).
+func Fchmod(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ mode := args[1].ModeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_MODE,
+ Mode: uint16(mode & chmodMask),
+ },
+ })
+}
+
+// Chown implements Linux syscall chown(2).
+func Chown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, 0 /* flags */)
+}
+
+// Lchown implements Linux syscall lchown(2).
+func Lchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ owner := args[1].Int()
+ group := args[2].Int()
+ return 0, nil, fchownat(t, linux.AT_FDCWD, pathAddr, owner, group, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Fchownat implements Linux syscall fchownat(2).
+func Fchownat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ owner := args[2].Int()
+ group := args[3].Int()
+ flags := args[4].Int()
+ return 0, nil, fchownat(t, dirfd, pathAddr, owner, group, flags)
+}
+
+func fchownat(t *kernel.Task, dirfd int32, pathAddr usermem.Addr, owner, group, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return err
+ }
+
+ return setstatat(t, dirfd, path, shouldAllowEmptyPath(flags&linux.AT_EMPTY_PATH != 0), shouldFollowFinalSymlink(flags&linux.AT_SYMLINK_NOFOLLOW == 0), &opts)
+}
+
+func populateSetStatOptionsForChown(t *kernel.Task, owner, group int32, opts *vfs.SetStatOptions) error {
+ userns := t.UserNamespace()
+ if owner != -1 {
+ kuid := userns.MapToKUID(auth.UID(owner))
+ if !kuid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_UID
+ opts.Stat.UID = uint32(kuid)
+ }
+ if group != -1 {
+ kgid := userns.MapToKGID(auth.GID(group))
+ if !kgid.Ok() {
+ return syserror.EINVAL
+ }
+ opts.Stat.Mask |= linux.STATX_GID
+ opts.Stat.GID = uint32(kgid)
+ }
+ return nil
+}
+
+// Fchown implements Linux syscall fchown(2).
+func Fchown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ owner := args[1].Int()
+ group := args[2].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForChown(t, owner, group, &opts); err != nil {
+ return 0, nil, err
+ }
+ return 0, nil, file.SetStat(t, opts)
+}
+
+// Truncate implements Linux syscall truncate(2).
+func Truncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ addr := args[0].Pointer()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, addr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Ftruncate implements Linux syscall ftruncate(2).
+func Ftruncate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ length := args[1].Int64()
+
+ if length < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SetStat(t, vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_SIZE,
+ Size: uint64(length),
+ },
+ })
+}
+
+// Utime implements Linux syscall utime(2).
+func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times linux.Utime
+ if err := times.CopyIn(t, timesAddr); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime.Sec = times.Actime
+ opts.Stat.Mtime.Sec = times.Modtime
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimes implements Linux syscall utimes(2).
+func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ timesAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ opts := vfs.SetStatOptions{
+ Stat: linux.Statx{
+ Mask: linux.STATX_ATIME | linux.STATX_MTIME,
+ },
+ }
+ if timesAddr == 0 {
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ } else {
+ var times [2]linux.Timeval
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return 0, nil, err
+ }
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Usec * 1000),
+ }
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Usec * 1000),
+ }
+ }
+
+ return 0, nil, setstatat(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Utimensat implements Linux syscall utimensat(2).
+func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ timesAddr := args[2].Pointer()
+ flags := args[3].Int()
+
+ if flags&^linux.AT_SYMLINK_NOFOLLOW != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, setstatat(t, dirfd, path, disallowEmptyPath, followFinalSymlink, &opts)
+}
+
+// Futimens implements Linux syscall futimens(2).
+func Futimens(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ timesAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ var opts vfs.SetStatOptions
+ if err := populateSetStatOptionsForUtimens(t, timesAddr, &opts); err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.SetStat(t, opts)
+}
+
+func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, opts *vfs.SetStatOptions) error {
+ if timesAddr == 0 {
+ opts.Stat.Mask = linux.STATX_ATIME | linux.STATX_MTIME
+ opts.Stat.Atime.Nsec = linux.UTIME_NOW
+ opts.Stat.Mtime.Nsec = linux.UTIME_NOW
+ return nil
+ }
+ var times [2]linux.Timespec
+ if _, err := t.CopyIn(timesAddr, &times); err != nil {
+ return err
+ }
+ if times[0].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_ATIME
+ opts.Stat.Atime = linux.StatxTimestamp{
+ Sec: times[0].Sec,
+ Nsec: uint32(times[0].Nsec),
+ }
+ }
+ if times[1].Nsec != linux.UTIME_OMIT {
+ opts.Stat.Mask |= linux.STATX_MTIME
+ opts.Stat.Mtime = linux.StatxTimestamp{
+ Sec: times[1].Sec,
+ Nsec: uint32(times[1].Nsec),
+ }
+ }
+ return nil
+}
+
+func setstatat(t *kernel.Task, dirfd int32, path fspath.Path, shouldAllowEmptyPath shouldAllowEmptyPath, shouldFollowFinalSymlink shouldFollowFinalSymlink, opts *vfs.SetStatOptions) error {
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && !bool(shouldAllowEmptyPath) {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.SetStat() instead of
+ // VirtualFilesystem.SetStatAt(), since the former may be able
+ // to use opened file state to expedite the SetStat.
+ err := dirfile.SetStat(t, *opts)
+ dirfile.DecRef()
+ return err
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+ return t.Kernel().VFS().SetStatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: bool(shouldFollowFinalSymlink),
+ }, opts)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/stat.go b/pkg/sentry/syscalls/linux/vfs2/stat.go
new file mode 100644
index 000000000..dca8d7011
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/stat.go
@@ -0,0 +1,346 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Stat implements Linux syscall stat(2).
+func Stat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, 0 /* flags */)
+}
+
+// Lstat implements Linux syscall lstat(2).
+func Lstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ statAddr := args[1].Pointer()
+ return 0, nil, fstatat(t, linux.AT_FDCWD, pathAddr, statAddr, linux.AT_SYMLINK_NOFOLLOW)
+}
+
+// Newfstatat implements Linux syscall newfstatat, which backs fstatat(2).
+func Newfstatat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ statAddr := args[2].Pointer()
+ flags := args[3].Int()
+ return 0, nil, fstatat(t, dirfd, pathAddr, statAddr, flags)
+}
+
+func fstatat(t *kernel.Task, dirfd int32, pathAddr, statAddr usermem.Addr, flags int32) error {
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for fstatat(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return stat.CopyOut(t, statAddr)
+}
+
+// This takes both input and output as pointer arguments to avoid copying large
+// structs.
+func convertStatxToUserStat(t *kernel.Task, statx *linux.Statx, stat *linux.Stat) {
+ // Linux just copies fields from struct kstat without regard to struct
+ // kstat::result_mask (fs/stat.c:cp_new_stat()), so we do too.
+ userns := t.UserNamespace()
+ *stat = linux.Stat{
+ Dev: uint64(linux.MakeDeviceID(uint16(statx.DevMajor), statx.DevMinor)),
+ Ino: statx.Ino,
+ Nlink: uint64(statx.Nlink),
+ Mode: uint32(statx.Mode),
+ UID: uint32(auth.KUID(statx.UID).In(userns).OrOverflow()),
+ GID: uint32(auth.KGID(statx.GID).In(userns).OrOverflow()),
+ Rdev: uint64(linux.MakeDeviceID(uint16(statx.RdevMajor), statx.RdevMinor)),
+ Size: int64(statx.Size),
+ Blksize: int64(statx.Blksize),
+ Blocks: int64(statx.Blocks),
+ ATime: timespecFromStatxTimestamp(statx.Atime),
+ MTime: timespecFromStatxTimestamp(statx.Mtime),
+ CTime: timespecFromStatxTimestamp(statx.Ctime),
+ }
+}
+
+func timespecFromStatxTimestamp(sxts linux.StatxTimestamp) linux.Timespec {
+ return linux.Timespec{
+ Sec: sxts.Sec,
+ Nsec: int64(sxts.Nsec),
+ }
+}
+
+// Fstat implements Linux syscall fstat(2).
+func Fstat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ statAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ statx, err := file.Stat(t, vfs.StatOptions{
+ Mask: linux.STATX_BASIC_STATS,
+ })
+ if err != nil {
+ return 0, nil, err
+ }
+ var stat linux.Stat
+ convertStatxToUserStat(t, &statx, &stat)
+ return 0, nil, stat.CopyOut(t, statAddr)
+}
+
+// Statx implements Linux syscall statx(2).
+func Statx(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ flags := args[2].Int()
+ mask := args[3].Uint()
+ statxAddr := args[4].Pointer()
+
+ if flags&^(linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ opts := vfs.StatOptions{
+ Mask: mask,
+ Sync: uint32(flags & linux.AT_STATX_SYNC_TYPE),
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ if !path.Absolute {
+ if !path.HasComponents() && flags&linux.AT_EMPTY_PATH == 0 {
+ return 0, nil, syserror.ENOENT
+ }
+ if dirfd == linux.AT_FDCWD {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ } else {
+ dirfile := t.GetFileVFS2(dirfd)
+ if dirfile == nil {
+ return 0, nil, syserror.EBADF
+ }
+ if !path.HasComponents() {
+ // Use FileDescription.Stat() instead of
+ // VirtualFilesystem.StatAt() for statx(fd, ""), since the
+ // former may be able to use opened file state to expedite the
+ // Stat.
+ statx, err := dirfile.Stat(t, opts)
+ dirfile.DecRef()
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+ }
+ start = dirfile.VirtualDentry()
+ start.IncRef()
+ defer start.DecRef()
+ dirfile.DecRef()
+ }
+ }
+
+ statx, err := t.Kernel().VFS().StatAt(t, t.Credentials(), &vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ FollowFinalSymlink: flags&linux.AT_SYMLINK_NOFOLLOW == 0,
+ }, &opts)
+ if err != nil {
+ return 0, nil, err
+ }
+ userifyStatx(t, &statx)
+ return 0, nil, statx.CopyOut(t, statxAddr)
+}
+
+func userifyStatx(t *kernel.Task, statx *linux.Statx) {
+ userns := t.UserNamespace()
+ statx.UID = uint32(auth.KUID(statx.UID).In(userns).OrOverflow())
+ statx.GID = uint32(auth.KGID(statx.GID).In(userns).OrOverflow())
+}
+
+// Readlink implements Linux syscall readlink(2).
+func Readlink(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+ size := args[2].SizeT()
+ return readlinkat(t, linux.AT_FDCWD, pathAddr, bufAddr, size)
+}
+
+// Access implements Linux syscall access(2).
+func Access(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Faccessat implements Linux syscall access(2).
+func Faccessat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // FIXME(jamieliu): actually implement
+ return 0, nil, nil
+}
+
+// Readlinkat implements Linux syscall mknodat(2).
+func Readlinkat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ dirfd := args[0].Int()
+ pathAddr := args[1].Pointer()
+ bufAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ return readlinkat(t, dirfd, pathAddr, bufAddr, size)
+}
+
+func readlinkat(t *kernel.Task, dirfd int32, pathAddr, bufAddr usermem.Addr, size uint) (uintptr, *kernel.SyscallControl, error) {
+ if int(size) <= 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ // "Since Linux 2.6.39, pathname can be an empty string, in which case the
+ // call operates on the symbolic link referred to by dirfd ..." -
+ // readlinkat(2)
+ tpop, err := getTaskPathOperation(t, dirfd, path, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ target, err := t.Kernel().VFS().ReadlinkAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if len(target) > int(size) {
+ target = target[:size]
+ }
+ n, err := t.CopyOutBytes(bufAddr, gohacks.ImmutableBytesFromString(target))
+ if n == 0 {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Statfs implements Linux syscall statfs(2).
+func Statfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ bufAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, followFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
+
+// Fstatfs implements Linux syscall fstatfs(2).
+func Fstatfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ bufAddr := args[1].Pointer()
+
+ tpop, err := getTaskPathOperation(t, fd, fspath.Path{}, allowEmptyPath, nofollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ statfs, err := t.Kernel().VFS().StatFSAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, statfs.CopyOut(t, bufAddr)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go
new file mode 100644
index 000000000..365250b0b
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/sync.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// Sync implements Linux syscall sync(2).
+func Sync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, t.Kernel().VFS().SyncAllFilesystems(t)
+}
+
+// Syncfs implements Linux syscall syncfs(2).
+func Syncfs(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.SyncFS(t)
+}
+
+// Fsync implements Linux syscall fsync(2).
+func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ return 0, nil, file.Sync(t)
+}
+
+// Fdatasync implements Linux syscall fdatasync(2).
+func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of unnecessary metadata.
+ return Fsync(t, args)
+}
+
+// SyncFileRange implements Linux syscall sync_file_range(2).
+func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ nbytes := args[2].Int64()
+ flags := args[3].Uint()
+
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if nbytes < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+ if flags&^(linux.SYNC_FILE_RANGE_WAIT_BEFORE|linux.SYNC_FILE_RANGE_WRITE|linux.SYNC_FILE_RANGE_WAIT_AFTER) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // TODO(gvisor.dev/issue/1897): Avoid writeback of data ranges outside of
+ // [offset, offset+nbytes).
+ return 0, nil, file.Sync(t)
+}
diff --git a/pkg/sentry/syscalls/linux/vfs2/sys_read.go b/pkg/sentry/syscalls/linux/vfs2/sys_read.go
deleted file mode 100644
index 7667524c7..000000000
--- a/pkg/sentry/syscalls/linux/vfs2/sys_read.go
+++ /dev/null
@@ -1,95 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package vfs2
-
-import (
- "gvisor.dev/gvisor/pkg/sentry/arch"
- "gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
- "gvisor.dev/gvisor/pkg/sentry/vfs"
- "gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/pkg/waiter"
-)
-
-const (
- // EventMaskRead contains events that can be triggered on reads.
- EventMaskRead = waiter.EventIn | waiter.EventHUp | waiter.EventErr
-)
-
-// Read implements linux syscall read(2). Note that we try to get a buffer that
-// is exactly the size requested because some applications like qemu expect
-// they can do large reads all at once. Bug for bug. Same for other read
-// calls below.
-func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
- fd := args[0].Int()
- addr := args[1].Pointer()
- size := args[2].SizeT()
-
- file := t.GetFileVFS2(fd)
- if file == nil {
- return 0, nil, syserror.EBADF
- }
- defer file.DecRef()
-
- // Check that the size is legitimate.
- si := int(size)
- if si < 0 {
- return 0, nil, syserror.EINVAL
- }
-
- // Get the destination of the read.
- dst, err := t.SingleIOSequence(addr, si, usermem.IOOpts{
- AddressSpaceActive: true,
- })
- if err != nil {
- return 0, nil, err
- }
-
- n, err := read(t, file, dst, vfs.ReadOptions{})
- t.IOUsage().AccountReadSyscall(n)
- return uintptr(n), nil, linux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
-}
-
-func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
- n, err := file.Read(t, dst, opts)
- if err != syserror.ErrWouldBlock {
- return n, err
- }
-
- // Register for notifications.
- w, ch := waiter.NewChannelEntry(nil)
- file.EventRegister(&w, EventMaskRead)
-
- total := n
- for {
- // Shorten dst to reflect bytes previously read.
- dst = dst.DropFirst(int(n))
-
- // Issue the request and break out if it completes with anything other than
- // "would block".
- n, err := file.Read(t, dst, opts)
- total += n
- if err != syserror.ErrWouldBlock {
- break
- }
- if err := t.Block(ch); err != nil {
- break
- }
- }
- file.EventUnregister(&w)
-
- return total, err
-}
diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go
new file mode 100644
index 000000000..89e9ff4d7
--- /dev/null
+++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go
@@ -0,0 +1,353 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package vfs2
+
+import (
+ "bytes"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/gohacks"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// Listxattr implements Linux syscall listxattr(2).
+func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, followFinalSymlink)
+}
+
+// Llistxattr implements Linux syscall llistxattr(2).
+func Llistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return listxattr(t, args, nofollowFinalSymlink)
+}
+
+func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Flistxattr implements Linux syscall flistxattr(2).
+func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ listAddr := args[1].Pointer()
+ size := args[2].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ names, err := file.Listxattr(t)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrNameList(t, listAddr, size, names)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Getxattr implements Linux syscall getxattr(2).
+func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, followFinalSymlink)
+}
+
+// Lgetxattr implements Linux syscall lgetxattr(2).
+func Lgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return getxattr(t, args, nofollowFinalSymlink)
+}
+
+func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) (uintptr, *kernel.SyscallControl, error) {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Fgetxattr implements Linux syscall fgetxattr(2).
+func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ value, err := file.Getxattr(t, name)
+ if err != nil {
+ return 0, nil, err
+ }
+ n, err := copyOutXattrValue(t, valueAddr, size, value)
+ if err != nil {
+ return 0, nil, err
+ }
+ return uintptr(n), nil, nil
+}
+
+// Setxattr implements Linux syscall setxattr(2).
+func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, followFinalSymlink)
+}
+
+// Lsetxattr implements Linux syscall lsetxattr(2).
+func Lsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, setxattr(t, args, nofollowFinalSymlink)
+}
+
+func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return syserror.EINVAL
+ }
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Fsetxattr implements Linux syscall fsetxattr(2).
+func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+ valueAddr := args[2].Pointer()
+ size := args[3].SizeT()
+ flags := args[4].Int()
+
+ if flags&^(linux.XATTR_CREATE|linux.XATTR_REPLACE) != 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+ value, err := copyInXattrValue(t, valueAddr, size)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Setxattr(t, vfs.SetxattrOptions{
+ Name: name,
+ Value: value,
+ Flags: uint32(flags),
+ })
+}
+
+// Removexattr implements Linux syscall removexattr(2).
+func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, followFinalSymlink)
+}
+
+// Lremovexattr implements Linux syscall lremovexattr(2).
+func Lremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ return 0, nil, removexattr(t, args, nofollowFinalSymlink)
+}
+
+func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymlink shouldFollowFinalSymlink) error {
+ pathAddr := args[0].Pointer()
+ nameAddr := args[1].Pointer()
+
+ path, err := copyInPath(t, pathAddr)
+ if err != nil {
+ return err
+ }
+ tpop, err := getTaskPathOperation(t, linux.AT_FDCWD, path, disallowEmptyPath, shouldFollowFinalSymlink)
+ if err != nil {
+ return err
+ }
+ defer tpop.Release()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return err
+ }
+
+ return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name)
+}
+
+// Fremovexattr implements Linux syscall fremovexattr(2).
+func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ nameAddr := args[1].Pointer()
+
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ name, err := copyInXattrName(t, nameAddr)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return 0, nil, file.Removexattr(t, name)
+}
+
+func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) {
+ name, err := t.CopyInString(nameAddr, linux.XATTR_NAME_MAX+1)
+ if err != nil {
+ if err == syserror.ENAMETOOLONG {
+ return "", syserror.ERANGE
+ }
+ return "", err
+ }
+ if len(name) == 0 {
+ return "", syserror.ERANGE
+ }
+ return name, nil
+}
+
+func copyOutXattrNameList(t *kernel.Task, listAddr usermem.Addr, size uint, names []string) (int, error) {
+ if size > linux.XATTR_LIST_MAX {
+ size = linux.XATTR_LIST_MAX
+ }
+ var buf bytes.Buffer
+ for _, name := range names {
+ buf.WriteString(name)
+ buf.WriteByte(0)
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the list.
+ return buf.Len(), nil
+ }
+ if buf.Len() > int(size) {
+ if size >= linux.XATTR_LIST_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(listAddr, buf.Bytes())
+}
+
+func copyInXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint) (string, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ return "", syserror.E2BIG
+ }
+ buf := make([]byte, size)
+ if _, err := t.CopyInBytes(valueAddr, buf); err != nil {
+ return "", err
+ }
+ return gohacks.StringFromImmutableBytes(buf), nil
+}
+
+func copyOutXattrValue(t *kernel.Task, valueAddr usermem.Addr, size uint, value string) (int, error) {
+ if size > linux.XATTR_SIZE_MAX {
+ size = linux.XATTR_SIZE_MAX
+ }
+ if size == 0 {
+ // Return the size that would be required to accomodate the value.
+ return len(value), nil
+ }
+ if len(value) > int(size) {
+ if size >= linux.XATTR_SIZE_MAX {
+ return 0, syserror.E2BIG
+ }
+ return 0, syserror.ERANGE
+ }
+ return t.CopyOutBytes(valueAddr, gohacks.ImmutableBytesFromString(value))
+}
diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD
index 0b4f18ab5..07c8383e6 100644
--- a/pkg/sentry/vfs/BUILD
+++ b/pkg/sentry/vfs/BUILD
@@ -43,6 +43,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index eed41139b..3da45d744 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -202,6 +202,9 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
// Add epi to file.epolls so that it is removed when the last
// FileDescription reference is dropped.
file.epollMu.Lock()
+ if file.epolls == nil {
+ file.epolls = make(map[*epollInterest]struct{})
+ }
file.epolls[epi] = struct{}{}
file.epollMu.Unlock()
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 9912df799..31a4e5480 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -139,6 +139,23 @@ func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth
return mntns, nil
}
+// NewDisconnectedMount returns a Mount representing fs with the given root
+// (which may be nil). The new Mount is not associated with any MountNamespace
+// and is not connected to any other Mounts. References are taken on fs and
+// root.
+func (vfs *VirtualFilesystem) NewDisconnectedMount(fs *Filesystem, root *Dentry, opts *MountOptions) (*Mount, error) {
+ fs.IncRef()
+ if root != nil {
+ root.IncRef()
+ }
+ return &Mount{
+ vfs: vfs,
+ fs: fs,
+ root: root,
+ refs: 1,
+ }, nil
+}
+
// MountAt creates and mounts a Filesystem configured by the given arguments.
func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error {
rft := vfs.getFilesystemType(fsTypeName)
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index 1fe766a44..bc7581698 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -26,6 +26,7 @@ import (
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -160,7 +161,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
loop:
for {
@@ -361,12 +362,3 @@ func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
//go:linkname rand32 runtime.fastrand
func rand32() uint32
-
-// This is copy/pasted from runtime.noescape(), and is needed because arguments
-// apparently escape from all functions defined by linkname.
-//
-//go:nosplit
-func noescape(p unsafe.Pointer) unsafe.Pointer {
- x := uintptr(p)
- return unsafe.Pointer(x ^ 0)
-}
diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go
index 8a0b382f6..eb4ebb511 100644
--- a/pkg/sentry/vfs/resolving_path.go
+++ b/pkg/sentry/vfs/resolving_path.go
@@ -228,7 +228,7 @@ func (rp *ResolvingPath) Advance() {
rp.pit = next
} else { // at end of path segment, continue with next one
rp.curPart--
- rp.pit = rp.parts[rp.curPart-1]
+ rp.pit = rp.parts[rp.curPart]
}
}
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 8f29031b2..bde81e1ef 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -126,17 +126,23 @@ func (vfs *VirtualFilesystem) Init() error {
// Construct vfs.anonMount.
anonfsDevMinor, err := vfs.GetAnonBlockDevMinor()
if err != nil {
- return err
+ // This shouldn't be possible since anonBlockDevMinorNext was
+ // initialized to 1 above (no device numbers have been allocated yet).
+ panic(fmt.Sprintf("VirtualFilesystem.Init: device number allocation for anonfs failed: %v", err))
}
anonfs := anonFilesystem{
devMinor: anonfsDevMinor,
}
anonfs.vfsfs.Init(vfs, &anonfs)
- vfs.anonMount = &Mount{
- vfs: vfs,
- fs: &anonfs.vfsfs,
- refs: 1,
+ defer anonfs.vfsfs.DecRef()
+ anonMount, err := vfs.NewDisconnectedMount(&anonfs.vfsfs, nil, &MountOptions{})
+ if err != nil {
+ // We should not be passing any MountOptions that would cause
+ // construction of this mount to fail.
+ panic(fmt.Sprintf("VirtualFilesystem.Init: anonfs mount failed: %v", err))
}
+ vfs.anonMount = anonMount
+
return nil
}
@@ -385,15 +391,11 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
// Only a regular file can be executed.
stat, err := fd.Stat(ctx, StatOptions{Mask: linux.STATX_TYPE})
if err != nil {
+ fd.DecRef()
return nil, err
}
- if stat.Mask&linux.STATX_TYPE != 0 {
- // This shouldn't happen, but if type can't be retrieved, file can't
- // be executed.
- return nil, syserror.EACCES
- }
- if t := linux.FileMode(stat.Mode).FileType(); t != linux.ModeRegular {
- ctx.Infof("%q is not a regular file: %v", pop.Path, t)
+ if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.S_IFMT != linux.S_IFREG {
+ fd.DecRef()
return nil, syserror.EACCES
}
}
diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go
index 2269f6237..4b5a0fca6 100644
--- a/pkg/syserror/syserror.go
+++ b/pkg/syserror/syserror.go
@@ -29,6 +29,7 @@ var (
EACCES = error(syscall.EACCES)
EAGAIN = error(syscall.EAGAIN)
EBADF = error(syscall.EBADF)
+ EBADFD = error(syscall.EBADFD)
EBUSY = error(syscall.EBUSY)
ECHILD = error(syscall.ECHILD)
ECONNREFUSED = error(syscall.ECONNREFUSED)
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index ea0a0409a..3c552988a 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -127,6 +127,10 @@ func TestCloseReader(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -175,6 +179,10 @@ func TestCloseReaderWithForwarder(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -225,30 +233,21 @@ func TestCloseRead(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
- ep, err := r.CreateEndpoint(&wq)
+ _, err := r.CreateEndpoint(&wq)
if err != nil {
t.Fatalf("r.CreateEndpoint() = %v", err)
}
- defer ep.Close()
- r.Complete(false)
-
- c := NewTCPConn(&wq, ep)
-
- buf := make([]byte, 256)
- n, e := c.Read(buf)
- if e != nil || string(buf[:n]) != "abc123" {
- t.Fatalf("c.Read() = (%d, %v), want (6, nil)", n, e)
- }
-
- if n, e = c.Write([]byte("abc123")); e != nil {
- t.Errorf("c.Write() = (%d, %v), want (6, nil)", n, e)
- }
+ // Endpoint will be closed in deferred s.Close (above).
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
@@ -278,6 +277,10 @@ func TestCloseWrite(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -334,6 +337,10 @@ func TestUDPForwarder(t *testing.T) {
if terr != nil {
t.Fatalf("newLoopbackStack() = %v", terr)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -391,6 +398,10 @@ func TestDeadlineChange(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
@@ -440,6 +451,10 @@ func TestPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
@@ -492,6 +507,10 @@ func TestConnectedPacketConnTransfer(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -562,6 +581,8 @@ func makePipe() (c1, c2 net.Conn, stop func(), err error) {
stop = func() {
c1.Close()
c2.Close()
+ s.Close()
+ s.Wait()
}
if err := l.Close(); err != nil {
@@ -624,6 +645,10 @@ func TestTCPDialError(t *testing.T) {
if e != nil {
t.Fatalf("newLoopbackStack() = %v", e)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
addr := tcpip.FullAddress{NICID, ip, 11211}
@@ -641,6 +666,10 @@ func TestDialContextTCPCanceled(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
@@ -659,6 +688,10 @@ func TestDialContextTCPTimeout(t *testing.T) {
if err != nil {
t.Fatalf("newLoopbackStack() = %v", err)
}
+ defer func() {
+ s.Close()
+ s.Wait()
+ }()
addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 150310c11..17e94c562 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -156,3 +156,9 @@ func (vv *VectorisedView) Append(vv2 VectorisedView) {
vv.views = append(vv.views, vv2.views...)
vv.size += vv2.size
}
+
+// AppendView appends the given view into this vectorised view.
+func (vv *VectorisedView) AppendView(v View) {
+ vv.views = append(vv.views, v)
+ vv.size += len(v)
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 70e6ce095..76e88e9b3 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -115,6 +115,19 @@ const (
// for the secret key used to generate an opaque interface identifier as
// outlined by RFC 7217.
OpaqueIIDSecretKeyMinBytes = 16
+
+ // ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field
+ // is located within a multicast IPv6 address, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeByteIdx = 1
+
+ // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
+ // within the byte holding the field, as per RFC 4291 section 2.7.
+ ipv6MulticastAddressScopeMask = 0xF
+
+ // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within
+ // a multicast IPv6 address that indicates the address has link-local scope,
+ // as per RFC 4291 section 2.7.
+ ipv6LinkLocalMulticastScope = 2
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -340,6 +353,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool {
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
+// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
+// link-local multicast address.
+func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
+ return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
+}
+
// IsV6UniqueLocalAddress determines if the provided address is an IPv6
// unique-local address (within the prefix FC00::/7).
func IsV6UniqueLocalAddress(addr tcpip.Address) bool {
@@ -411,6 +430,9 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, *tcpip.Error) {
}
switch {
+ case IsV6LinkLocalMulticastAddress(addr):
+ return LinkLocalScope, nil
+
case IsV6LinkLocalAddress(addr):
return LinkLocalScope, nil
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index c3ad503aa..426a873b1 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -27,11 +27,12 @@ import (
)
const (
- linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
- linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkAddr = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
+ linkLocalAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
)
func TestEthernetAdddressToModifiedEUI64(t *testing.T) {
@@ -256,6 +257,85 @@ func TestIsV6UniqueLocalAddress(t *testing.T) {
}
}
+func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: true,
+ },
+ {
+ name: "Valid Link Local Multicast with flags",
+ addr: "\xff\xf2\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ expected: true,
+ },
+ {
+ name: "Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Multicast",
+ addr: "\xe0\x00\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalMulticastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV6LinkLocalAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid Link Local Unicast",
+ addr: linkLocalAddr,
+ expected: true,
+ },
+ {
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ expected: false,
+ },
+ {
+ name: "Unique Local",
+ addr: uniqueLocalAddr1,
+ expected: false,
+ },
+ {
+ name: "Global",
+ addr: globalAddr,
+ expected: false,
+ },
+ {
+ name: "IPv4 Link Local",
+ addr: "\xa9\xfe\x00\x01",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
func TestScopeForIPv6Address(t *testing.T) {
tests := []struct {
name string
@@ -270,12 +350,18 @@ func TestScopeForIPv6Address(t *testing.T) {
err: nil,
},
{
- name: "Link Local",
+ name: "Link Local Unicast",
addr: linkLocalAddr,
scope: header.LinkLocalScope,
err: nil,
},
{
+ name: "Link Local Multicast",
+ addr: linkLocalMulticastAddr,
+ scope: header.LinkLocalScope,
+ err: nil,
+ },
+ {
name: "Global",
addr: globalAddr,
scope: header.GlobalScope,
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
index 3974c464e..b8b93e78e 100644
--- a/pkg/tcpip/link/channel/BUILD
+++ b/pkg/tcpip/link/channel/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = ["channel.go"],
visibility = ["//visibility:public"],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/stack",
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index 78d447acd..5944ba190 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -20,6 +20,7 @@ package channel
import (
"context"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -33,6 +34,118 @@ type PacketInfo struct {
Route stack.Route
}
+// Notification is the interface for receiving notification from the packet
+// queue.
+type Notification interface {
+ // WriteNotify will be called when a write happens to the queue.
+ WriteNotify()
+}
+
+// NotificationHandle is an opaque handle to the registered notification target.
+// It can be used to unregister the notification when no longer interested.
+//
+// +stateify savable
+type NotificationHandle struct {
+ n Notification
+}
+
+type queue struct {
+ // mu protects fields below.
+ mu sync.RWMutex
+ // c is the outbound packet channel. Sending to c should hold mu.
+ c chan PacketInfo
+ numWrite int
+ numRead int
+ notify []*NotificationHandle
+}
+
+func (q *queue) Close() {
+ close(q.c)
+}
+
+func (q *queue) Read() (PacketInfo, bool) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ select {
+ case p := <-q.c:
+ q.numRead++
+ return p, true
+ default:
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) ReadContext(ctx context.Context) (PacketInfo, bool) {
+ // We have to receive from channel without holding the lock, since it can
+ // block indefinitely. This will cause a window that numWrite - numRead
+ // produces a larger number, but won't go to negative. numWrite >= numRead
+ // still holds.
+ select {
+ case pkt := <-q.c:
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.numRead++
+ return pkt, true
+ case <-ctx.Done():
+ return PacketInfo{}, false
+ }
+}
+
+func (q *queue) Write(p PacketInfo) bool {
+ wrote := false
+
+ // It's important to make sure nobody can see numWrite until we increment it,
+ // so numWrite >= numRead holds.
+ q.mu.Lock()
+ select {
+ case q.c <- p:
+ wrote = true
+ q.numWrite++
+ default:
+ }
+ notify := q.notify
+ q.mu.Unlock()
+
+ if wrote {
+ // Send notification outside of lock.
+ for _, h := range notify {
+ h.n.WriteNotify()
+ }
+ }
+ return wrote
+}
+
+func (q *queue) Num() int {
+ q.mu.RLock()
+ defer q.mu.RUnlock()
+ n := q.numWrite - q.numRead
+ if n < 0 {
+ panic("numWrite < numRead")
+ }
+ return n
+}
+
+func (q *queue) AddNotify(notify Notification) *NotificationHandle {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ h := &NotificationHandle{n: notify}
+ q.notify = append(q.notify, h)
+ return h
+}
+
+func (q *queue) RemoveNotify(handle *NotificationHandle) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ // Make a copy, since we reads the array outside of lock when notifying.
+ notify := make([]*NotificationHandle, 0, len(q.notify))
+ for _, h := range q.notify {
+ if h != handle {
+ notify = append(notify, h)
+ }
+ }
+ q.notify = notify
+}
+
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
@@ -41,14 +154,16 @@ type Endpoint struct {
linkAddr tcpip.LinkAddress
LinkEPCapabilities stack.LinkEndpointCapabilities
- // c is where outbound packets are queued.
- c chan PacketInfo
+ // Outbound packet queue.
+ q *queue
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
return &Endpoint{
- c: make(chan PacketInfo, size),
+ q: &queue{
+ c: make(chan PacketInfo, size),
+ },
mtu: mtu,
linkAddr: linkAddr,
}
@@ -57,43 +172,36 @@ func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) *Endpoint {
// Close closes e. Further packet injections will panic. Reads continue to
// succeed until all packets are read.
func (e *Endpoint) Close() {
- close(e.c)
+ e.q.Close()
}
-// Read does non-blocking read for one packet from the outbound packet queue.
+// Read does non-blocking read one packet from the outbound packet queue.
func (e *Endpoint) Read() (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- default:
- return PacketInfo{}, false
- }
+ return e.q.Read()
}
// ReadContext does blocking read for one packet from the outbound packet queue.
// It can be cancelled by ctx, and in this case, it returns false.
func (e *Endpoint) ReadContext(ctx context.Context) (PacketInfo, bool) {
- select {
- case pkt := <-e.c:
- return pkt, true
- case <-ctx.Done():
- return PacketInfo{}, false
- }
+ return e.q.ReadContext(ctx)
}
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
- select {
- case <-e.c:
- c++
- default:
+ if _, ok := e.Read(); !ok {
return c
}
+ c++
}
}
+// NumQueued returns the number of packet queued for outbound.
+func (e *Endpoint) NumQueued() int {
+ return e.q.Num()
+}
+
// InjectInbound injects an inbound packet.
func (e *Endpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
e.InjectLinkAddr(protocol, "", pkt)
@@ -155,10 +263,7 @@ func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.Ne
Route: route,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
@@ -171,7 +276,6 @@ func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts []tcpip.Pac
route.Release()
payloadView := pkts[0].Data.ToView()
n := 0
-packetLoop:
for _, pkt := range pkts {
off := pkt.DataOffset
size := pkt.DataSize
@@ -185,12 +289,10 @@ packetLoop:
Route: route,
}
- select {
- case e.c <- p:
- n++
- default:
- break packetLoop
+ if !e.q.Write(p) {
+ break
}
+ n++
}
return n, nil
@@ -204,13 +306,21 @@ func (e *Endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
GSO: nil,
}
- select {
- case e.c <- p:
- default:
- }
+ e.q.Write(p)
return nil
}
// Wait implements stack.LinkEndpoint.Wait.
func (*Endpoint) Wait() {}
+
+// AddNotify adds a notification target for receiving event about outgoing
+// packets.
+func (e *Endpoint) AddNotify(notify Notification) *NotificationHandle {
+ return e.q.AddNotify(notify)
+}
+
+// RemoveNotify removes handle from the list of notification targets.
+func (e *Endpoint) RemoveNotify(handle *NotificationHandle) {
+ e.q.RemoveNotify(handle)
+}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index e5096ea38..e0db6cf54 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -4,6 +4,22 @@ package(licenses = ["notice"])
go_library(
name = "tun",
- srcs = ["tun_unsafe.go"],
+ srcs = [
+ "device.go",
+ "protocol.go",
+ "tun_unsafe.go",
+ ],
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
)
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
new file mode 100644
index 000000000..6ff47a742
--- /dev/null
+++ b/pkg/tcpip/link/tun/device.go
@@ -0,0 +1,352 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ // drivers/net/tun.c:tun_net_init()
+ defaultDevMtu = 1500
+
+ // Queue length for outbound packet, arriving at fd side for read. Overflow
+ // causes packet drops. gVisor implementation-specific.
+ defaultDevOutQueueLen = 1024
+)
+
+var zeroMAC [6]byte
+
+// Device is an opened /dev/net/tun device.
+//
+// +stateify savable
+type Device struct {
+ waiter.Queue
+
+ mu sync.RWMutex `state:"nosave"`
+ endpoint *tunEndpoint
+ notifyHandle *channel.NotificationHandle
+ flags uint16
+}
+
+// beforeSave is invoked by stateify.
+func (d *Device) beforeSave() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ // TODO(b/110961832): Restore the device to stack. At this moment, the stack
+ // is not savable.
+ if d.endpoint != nil {
+ panic("/dev/net/tun does not support save/restore when a device is associated with it.")
+ }
+}
+
+// Release implements fs.FileOperations.Release.
+func (d *Device) Release() {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ // Decrease refcount if there is an endpoint associated with this file.
+ if d.endpoint != nil {
+ d.endpoint.RemoveNotify(d.notifyHandle)
+ d.endpoint.DecRef()
+ d.endpoint = nil
+ }
+}
+
+// SetIff services TUNSETIFF ioctl(2) request.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+
+ if d.endpoint != nil {
+ return syserror.EINVAL
+ }
+
+ // Input validations.
+ isTun := flags&linux.IFF_TUN != 0
+ isTap := flags&linux.IFF_TAP != 0
+ supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
+ if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
+ return syserror.EINVAL
+ }
+
+ prefix := "tun"
+ if isTap {
+ prefix = "tap"
+ }
+
+ endpoint, err := attachOrCreateNIC(s, name, prefix)
+ if err != nil {
+ return syserror.EINVAL
+ }
+
+ d.endpoint = endpoint
+ d.notifyHandle = d.endpoint.AddNotify(d)
+ d.flags = flags
+ return nil
+}
+
+func attachOrCreateNIC(s *stack.Stack, name, prefix string) (*tunEndpoint, error) {
+ for {
+ // 1. Try to attach to an existing NIC.
+ if name != "" {
+ if nic, found := s.GetNICByName(name); found {
+ endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if !ok {
+ // Not a NIC created by tun device.
+ return nil, syserror.EOPNOTSUPP
+ }
+ if !endpoint.TryIncRef() {
+ // Race detected: NIC got deleted in between.
+ continue
+ }
+ return endpoint, nil
+ }
+ }
+
+ // 2. Creating a new NIC.
+ id := tcpip.NICID(s.UniqueID())
+ endpoint := &tunEndpoint{
+ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""),
+ stack: s,
+ nicID: id,
+ name: name,
+ }
+ if endpoint.name == "" {
+ endpoint.name = fmt.Sprintf("%s%d", prefix, id)
+ }
+ err := s.CreateNICWithOptions(endpoint.nicID, endpoint, stack.NICOptions{
+ Name: endpoint.name,
+ })
+ switch err {
+ case nil:
+ return endpoint, nil
+ case tcpip.ErrDuplicateNICID:
+ // Race detected: A NIC has been created in between.
+ continue
+ default:
+ return nil, syserror.EINVAL
+ }
+ }
+}
+
+// Write inject one inbound packet to the network interface.
+func (d *Device) Write(data []byte) (int64, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return 0, syserror.EBADFD
+ }
+ if !endpoint.IsAttached() {
+ return 0, syserror.EIO
+ }
+
+ dataLen := int64(len(data))
+
+ // Packet information.
+ var pktInfoHdr PacketInfoHeader
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ if len(data) < PacketInfoHeaderSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ pktInfoHdr = PacketInfoHeader(data[:PacketInfoHeaderSize])
+ data = data[PacketInfoHeaderSize:]
+ }
+
+ // Ethernet header (TAP only).
+ var ethHdr header.Ethernet
+ if d.hasFlags(linux.IFF_TAP) {
+ if len(data) < header.EthernetMinimumSize {
+ // Ignore bad packet.
+ return dataLen, nil
+ }
+ ethHdr = header.Ethernet(data[:header.EthernetMinimumSize])
+ data = data[header.EthernetMinimumSize:]
+ }
+
+ // Try to determine network protocol number, default zero.
+ var protocol tcpip.NetworkProtocolNumber
+ switch {
+ case pktInfoHdr != nil:
+ protocol = pktInfoHdr.Protocol()
+ case ethHdr != nil:
+ protocol = ethHdr.Type()
+ }
+
+ // Try to determine remote link address, default zero.
+ var remote tcpip.LinkAddress
+ switch {
+ case ethHdr != nil:
+ remote = ethHdr.SourceAddress()
+ default:
+ remote = tcpip.LinkAddress(zeroMAC[:])
+ }
+
+ pkt := tcpip.PacketBuffer{
+ Data: buffer.View(data).ToVectorisedView(),
+ }
+ if ethHdr != nil {
+ pkt.LinkHeader = buffer.View(ethHdr)
+ }
+ endpoint.InjectLinkAddr(protocol, remote, pkt)
+ return dataLen, nil
+}
+
+// Read reads one outgoing packet from the network interface.
+func (d *Device) Read() ([]byte, error) {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint == nil {
+ return nil, syserror.EBADFD
+ }
+
+ for {
+ info, ok := endpoint.Read()
+ if !ok {
+ return nil, syserror.ErrWouldBlock
+ }
+
+ v, ok := d.encodePkt(&info)
+ if !ok {
+ // Ignore unsupported packet.
+ continue
+ }
+ return v, nil
+ }
+}
+
+// encodePkt encodes packet for fd side.
+func (d *Device) encodePkt(info *channel.PacketInfo) (buffer.View, bool) {
+ var vv buffer.VectorisedView
+
+ // Packet information.
+ if !d.hasFlags(linux.IFF_NO_PI) {
+ hdr := make(PacketInfoHeader, PacketInfoHeaderSize)
+ hdr.Encode(&PacketInfoFields{
+ Protocol: info.Proto,
+ })
+ vv.AppendView(buffer.View(hdr))
+ }
+
+ // If the packet does not already have link layer header, and the route
+ // does not exist, we can't compute it. This is possibly a raw packet, tun
+ // device doesn't support this at the moment.
+ if info.Pkt.LinkHeader == nil && info.Route.RemoteLinkAddress == "" {
+ return nil, false
+ }
+
+ // Ethernet header (TAP only).
+ if d.hasFlags(linux.IFF_TAP) {
+ // Add ethernet header if not provided.
+ if info.Pkt.LinkHeader == nil {
+ hdr := &header.EthernetFields{
+ SrcAddr: info.Route.LocalLinkAddress,
+ DstAddr: info.Route.RemoteLinkAddress,
+ Type: info.Proto,
+ }
+ if hdr.SrcAddr == "" {
+ hdr.SrcAddr = d.endpoint.LinkAddress()
+ }
+
+ eth := make(header.Ethernet, header.EthernetMinimumSize)
+ eth.Encode(hdr)
+ vv.AppendView(buffer.View(eth))
+ } else {
+ vv.AppendView(info.Pkt.LinkHeader)
+ }
+ }
+
+ // Append upper headers.
+ vv.AppendView(buffer.View(info.Pkt.Header.View()[len(info.Pkt.LinkHeader):]))
+ // Append data payload.
+ vv.Append(info.Pkt.Data)
+
+ return vv.ToView(), true
+}
+
+// Name returns the name of the attached network interface. Empty string if
+// unattached.
+func (d *Device) Name() string {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ if d.endpoint != nil {
+ return d.endpoint.name
+ }
+ return ""
+}
+
+// Flags returns the flags set for d. Zero value if unset.
+func (d *Device) Flags() uint16 {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+ return d.flags
+}
+
+func (d *Device) hasFlags(flags uint16) bool {
+ return d.flags&flags == flags
+}
+
+// Readiness implements watier.Waitable.Readiness.
+func (d *Device) Readiness(mask waiter.EventMask) waiter.EventMask {
+ if mask&waiter.EventIn != 0 {
+ d.mu.RLock()
+ endpoint := d.endpoint
+ d.mu.RUnlock()
+ if endpoint != nil && endpoint.NumQueued() == 0 {
+ mask &= ^waiter.EventIn
+ }
+ }
+ return mask & (waiter.EventIn | waiter.EventOut)
+}
+
+// WriteNotify implements channel.Notification.WriteNotify.
+func (d *Device) WriteNotify() {
+ d.Notify(waiter.EventIn)
+}
+
+// tunEndpoint is the link endpoint for the NIC created by the tun device.
+//
+// It is ref-counted as multiple opening files can attach to the same NIC.
+// The last owner is responsible for deleting the NIC.
+type tunEndpoint struct {
+ *channel.Endpoint
+
+ refs.AtomicRefCount
+
+ stack *stack.Stack
+ nicID tcpip.NICID
+ name string
+}
+
+// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+func (e *tunEndpoint) DecRef() {
+ e.DecRefWithDestructor(func() {
+ e.stack.RemoveNIC(e.nicID)
+ })
+}
diff --git a/pkg/tcpip/link/tun/protocol.go b/pkg/tcpip/link/tun/protocol.go
new file mode 100644
index 000000000..89d9d91a9
--- /dev/null
+++ b/pkg/tcpip/link/tun/protocol.go
@@ -0,0 +1,56 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tun
+
+import (
+ "encoding/binary"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+const (
+ // PacketInfoHeaderSize is the size of the packet information header.
+ PacketInfoHeaderSize = 4
+
+ offsetFlags = 0
+ offsetProtocol = 2
+)
+
+// PacketInfoFields contains fields sent through the wire if IFF_NO_PI flag is
+// not set.
+type PacketInfoFields struct {
+ Flags uint16
+ Protocol tcpip.NetworkProtocolNumber
+}
+
+// PacketInfoHeader is the wire representation of the packet information sent if
+// IFF_NO_PI flag is not set.
+type PacketInfoHeader []byte
+
+// Encode encodes f into h.
+func (h PacketInfoHeader) Encode(f *PacketInfoFields) {
+ binary.BigEndian.PutUint16(h[offsetFlags:][:2], f.Flags)
+ binary.BigEndian.PutUint16(h[offsetProtocol:][:2], uint16(f.Protocol))
+}
+
+// Flags returns the flag field in h.
+func (h PacketInfoHeader) Flags() uint16 {
+ return binary.BigEndian.Uint16(h[offsetFlags:])
+}
+
+// Protocol returns the protocol field in h.
+func (h PacketInfoHeader) Protocol() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(h[offsetProtocol:]))
+}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 4da13c5df..e9fcc89a8 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -148,12 +148,12 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWi
}, nil
}
-// LinkAddressProtocol implements stack.LinkAddressResolver.
+// LinkAddressProtocol implements stack.LinkAddressResolver.LinkAddressProtocol.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
-// LinkAddressRequest implements stack.LinkAddressResolver.
+// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
@@ -172,7 +172,7 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
})
}
-// ResolveStaticAddress implements stack.LinkAddressResolver.
+// ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == header.IPv4Broadcast {
return broadcastMAC, true
@@ -183,16 +183,22 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo
return tcpip.LinkAddress([]byte(nil)), false
}
-// SetOption implements NetworkProtocol.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.NetworkProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements NetworkProtocol.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.NetworkProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
// NewProtocol returns an ARP network protocol.
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 6597e6781..4f1742938 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -473,6 +473,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 180a480fd..9aef5234b 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -265,6 +265,12 @@ func (p *protocol) DefaultTTL() uint8 {
return uint8(atomic.LoadUint32(&p.defaultTTL))
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
diff --git a/pkg/tcpip/stack/ndp.go b/pkg/tcpip/stack/ndp.go
index 045409bda..f651871ce 100644
--- a/pkg/tcpip/stack/ndp.go
+++ b/pkg/tcpip/stack/ndp.go
@@ -1148,22 +1148,27 @@ func (ndp *ndpState) cleanupAutoGenAddrResourcesAndNotify(addr tcpip.Address) bo
return true
}
-// cleanupHostOnlyState cleans up any state that is only useful for hosts.
+// cleanupState cleans up ndp's state.
//
-// cleanupHostOnlyState MUST be called when ndp's NIC is transitioning from a
-// host to a router. This function will invalidate all discovered on-link
-// prefixes, discovered routers, and auto-generated addresses as routers do not
-// normally process Router Advertisements to discover default routers and
-// on-link prefixes, and auto-generate addresses via SLAAC.
+// If hostOnly is true, then only host-specific state will be cleaned up.
+//
+// cleanupState MUST be called with hostOnly set to true when ndp's NIC is
+// transitioning from a host to a router. This function will invalidate all
+// discovered on-link prefixes, discovered routers, and auto-generated
+// addresses.
+//
+// If hostOnly is true, then the link-local auto-generated address will not be
+// invalidated as routers are also expected to generate a link-local address.
//
// The NIC that ndp belongs to MUST be locked.
-func (ndp *ndpState) cleanupHostOnlyState() {
+func (ndp *ndpState) cleanupState(hostOnly bool) {
linkLocalSubnet := header.IPv6LinkLocalPrefix.Subnet()
linkLocalAddrs := 0
for addr := range ndp.autoGenAddresses {
// RFC 4862 section 5 states that routers are also expected to generate a
- // link-local address so we do not invalidate them.
- if linkLocalSubnet.Contains(addr) {
+ // link-local address so we do not invalidate them if we are cleaning up
+ // host-only state.
+ if hostOnly && linkLocalSubnet.Contains(addr) {
linkLocalAddrs++
continue
}
@@ -1230,7 +1235,7 @@ func (ndp *ndpState) startSolicitingRouters() {
}
payloadSize := header.ICMPv6HeaderSize + header.NDPRSMinimumSize
- hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + payloadSize)
pkt := header.ICMPv6(hdr.Prepend(payloadSize))
pkt.SetType(header.ICMPv6RouterSolicit)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 1f6f77439..6e9306d09 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -267,6 +267,17 @@ func (n *ndpDispatcher) OnDHCPv6Configuration(nicID tcpip.NICID, configuration s
}
}
+// channelLinkWithHeaderLength is a channel.Endpoint with a configurable
+// header length.
+type channelLinkWithHeaderLength struct {
+ *channel.Endpoint
+ headerLength uint16
+}
+
+func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
+ return l.headerLength
+}
+
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err *tcpip.Error) string {
@@ -323,21 +334,46 @@ func TestDADDisabled(t *testing.T) {
// DAD for various values of DupAddrDetectTransmits and RetransmitTimer.
// Included in the subtests is a test to make sure that an invalid
// RetransmitTimer (<1ms) values get fixed to the default RetransmitTimer of 1s.
+// This tests also validates the NDP NS packet that is transmitted.
func TestDADResolve(t *testing.T) {
const nicID = 1
tests := []struct {
name string
+ linkHeaderLen uint16
dupAddrDetectTransmits uint8
retransTimer time.Duration
expectedRetransmitTimer time.Duration
}{
- {"1:1s:1s", 1, time.Second, time.Second},
- {"2:1s:1s", 2, time.Second, time.Second},
- {"1:2s:2s", 1, 2 * time.Second, 2 * time.Second},
+ {
+ name: "1:1s:1s",
+ dupAddrDetectTransmits: 1,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "2:1s:1s",
+ linkHeaderLen: 1,
+ dupAddrDetectTransmits: 2,
+ retransTimer: time.Second,
+ expectedRetransmitTimer: time.Second,
+ },
+ {
+ name: "1:2s:2s",
+ linkHeaderLen: 2,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 2 * time.Second,
+ expectedRetransmitTimer: 2 * time.Second,
+ },
// 0s is an invalid RetransmitTimer timer and will be fixed to
// the default RetransmitTimer value of 1s.
- {"1:0s:1s", 1, 0, time.Second},
+ {
+ name: "1:0s:1s",
+ linkHeaderLen: 3,
+ dupAddrDetectTransmits: 1,
+ retransTimer: 0,
+ expectedRetransmitTimer: time.Second,
+ },
}
for _, test := range tests {
@@ -356,10 +392,13 @@ func TestDADResolve(t *testing.T) {
opts.NDPConfigs.RetransmitTimer = test.retransTimer
opts.NDPConfigs.DupAddrDetectTransmits = test.dupAddrDetectTransmits
- e := channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.dupAddrDetectTransmits), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
+ if err := s.CreateNIC(nicID, &e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -445,6 +484,10 @@ func TestDADResolve(t *testing.T) {
checker.NDPNSTargetAddress(addr1),
checker.NDPNSOptions(nil),
))
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
})
}
@@ -592,70 +635,94 @@ func TestDADFail(t *testing.T) {
}
}
-// TestDADStop tests to make sure that the DAD process stops when an address is
-// removed.
func TestDADStop(t *testing.T) {
const nicID = 1
- ndpDisp := ndpDispatcher{
- dadC: make(chan ndpDADEvent, 1),
- }
- ndpConfigs := stack.NDPConfigurations{
- RetransmitTimer: time.Second,
- DupAddrDetectTransmits: 2,
- }
- opts := stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPDisp: &ndpDisp,
- NDPConfigs: ndpConfigs,
- }
+ tests := []struct {
+ name string
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests to make sure that DAD stops when an address is removed.
+ {
+ name: "Remove address",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.RemoveAddress(nicID, addr1); err != nil {
+ t.Fatalf("RemoveAddress(%d, %s): %s", nicID, addr1, err)
+ }
+ },
+ },
- e := channel.New(0, 1280, linkAddr1)
- s := stack.New(opts)
- if err := s.CreateNIC(nicID, e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ // Tests to make sure that DAD stops when the NIC is disabled.
+ {
+ name: "Disable NIC",
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ dadC: make(chan ndpDADEvent, 1),
+ }
+ ndpConfigs := stack.NDPConfigurations{
+ RetransmitTimer: time.Second,
+ DupAddrDetectTransmits: 2,
+ }
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPDisp: &ndpDisp,
+ NDPConfigs: ndpConfigs,
+ }
- // Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ e := channel.New(0, 1280, linkAddr1)
+ s := stack.New(opts)
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
- // Remove the address. This should stop DAD.
- if err := s.RemoveAddress(nicID, addr1); err != nil {
- t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr1, err)
- }
+ if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ }
- // Wait for DAD to fail (since the address was removed during DAD).
- select {
- case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
- // If we don't get a failure event after the expected resolution
- // time + extra 1s buffer, something is wrong.
- t.Fatal("timed out waiting for DAD failure")
- case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
- }
- }
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
- }
+ // Address should not be considered bound to the NIC yet (DAD ongoing).
+ addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
+
+ test.stopFn(t, s)
+
+ // Wait for DAD to fail (since the address was removed during DAD).
+ select {
+ case <-time.After(time.Duration(ndpConfigs.DupAddrDetectTransmits)*ndpConfigs.RetransmitTimer + time.Second):
+ // If we don't get a failure event after the expected resolution
+ // time + extra 1s buffer, something is wrong.
+ t.Fatal("timed out waiting for DAD failure")
+ case e := <-ndpDisp.dadC:
+ if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
+ t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ }
+ }
+ addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
+ if err != nil {
+ t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
+ }
+ if want := (tcpip.AddressWithPrefix{}); addr != want {
+ t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ }
- // Should not have sent more than 1 NS message.
- if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
- t.Fatalf("got NeighborSolicit = %d, want <= 1", got)
+ // Should not have sent more than 1 NS message.
+ if got := s.Stats().ICMP.V6PacketsSent.NeighborSolicit.Value(); got > 1 {
+ t.Errorf("got NeighborSolicit = %d, want <= 1", got)
+ }
+ })
}
}
@@ -2886,17 +2953,16 @@ func TestNDPRecursiveDNSServerDispatch(t *testing.T) {
}
}
-// TestCleanupHostOnlyStateOnBecomingRouter tests that all discovered routers
-// and prefixes, and non-linklocal auto-generated addresses are invalidated when
-// a NIC becomes a router.
-func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
+// TestCleanupNDPState tests that all discovered routers and prefixes, and
+// auto-generated addresses are invalidated when a NIC becomes a router.
+func TestCleanupNDPState(t *testing.T) {
t.Parallel()
const (
- lifetimeSeconds = 5
- maxEvents = 4
- nicID1 = 1
- nicID2 = 2
+ lifetimeSeconds = 5
+ maxRouterAndPrefixEvents = 4
+ nicID1 = 1
+ nicID2 = 2
)
prefix1, subnet1, e1Addr1 := prefixSubnetAddr(0, linkAddr1)
@@ -2912,254 +2978,308 @@ func TestCleanupHostOnlyStateOnBecomingRouter(t *testing.T) {
PrefixLen: 64,
}
- ndpDisp := ndpDispatcher{
- routerC: make(chan ndpRouterEvent, maxEvents),
- rememberRouter: true,
- prefixC: make(chan ndpPrefixEvent, maxEvents),
- rememberPrefix: true,
- autoGenAddrC: make(chan ndpAutoGenAddrEvent, maxEvents),
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- AutoGenIPv6LinkLocal: true,
- NDPConfigs: stack.NDPConfigurations{
- HandleRAs: true,
- DiscoverDefaultRouters: true,
- DiscoverOnLinkPrefixes: true,
- AutoGenGlobalAddresses: true,
+ tests := []struct {
+ name string
+ cleanupFn func(t *testing.T, s *stack.Stack)
+ keepAutoGenLinkLocal bool
+ maxAutoGenAddrEvents int
+ }{
+ // A NIC should still keep its auto-generated link-local address when
+ // becoming a router.
+ {
+ name: "Forwarding Enable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
+ keepAutoGenLinkLocal: true,
+ maxAutoGenAddrEvents: 4,
},
- NDPDisp: &ndpDisp,
- })
- expectRouterEvent := func() (bool, ndpRouterEvent) {
- select {
- case e := <-ndpDisp.routerC:
- return true, e
- default:
- }
+ // A NIC should cleanup all NDP state when it is disabled.
+ {
+ name: "NIC Disable",
+ cleanupFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- return false, ndpRouterEvent{}
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ },
+ keepAutoGenLinkLocal: false,
+ maxAutoGenAddrEvents: 6,
+ },
}
- expectPrefixEvent := func() (bool, ndpPrefixEvent) {
- select {
- case e := <-ndpDisp.prefixC:
- return true, e
- default:
- }
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ndpDisp := ndpDispatcher{
+ routerC: make(chan ndpRouterEvent, maxRouterAndPrefixEvents),
+ rememberRouter: true,
+ prefixC: make(chan ndpPrefixEvent, maxRouterAndPrefixEvents),
+ rememberPrefix: true,
+ autoGenAddrC: make(chan ndpAutoGenAddrEvent, test.maxAutoGenAddrEvents),
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ AutoGenIPv6LinkLocal: true,
+ NDPConfigs: stack.NDPConfigurations{
+ HandleRAs: true,
+ DiscoverDefaultRouters: true,
+ DiscoverOnLinkPrefixes: true,
+ AutoGenGlobalAddresses: true,
+ },
+ NDPDisp: &ndpDisp,
+ })
- return false, ndpPrefixEvent{}
- }
+ expectRouterEvent := func() (bool, ndpRouterEvent) {
+ select {
+ case e := <-ndpDisp.routerC:
+ return true, e
+ default:
+ }
- expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
- select {
- case e := <-ndpDisp.autoGenAddrC:
- return true, e
- default:
- }
+ return false, ndpRouterEvent{}
+ }
- return false, ndpAutoGenAddrEvent{}
- }
+ expectPrefixEvent := func() (bool, ndpPrefixEvent) {
+ select {
+ case e := <-ndpDisp.prefixC:
+ return true, e
+ default:
+ }
- e1 := channel.New(0, 1280, linkAddr1)
- if err := s.CreateNIC(nicID1, e1); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
- }
- // We have other tests that make sure we receive the *correct* events
- // on normal discovery of routers/prefixes, and auto-generated
- // addresses. Here we just make sure we get an event and let other tests
- // handle the correctness check.
- expectAutoGenAddrEvent()
+ return false, ndpPrefixEvent{}
+ }
- e2 := channel.New(0, 1280, linkAddr2)
- if err := s.CreateNIC(nicID2, e2); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
- }
- expectAutoGenAddrEvent()
+ expectAutoGenAddrEvent := func() (bool, ndpAutoGenAddrEvent) {
+ select {
+ case e := <-ndpDisp.autoGenAddrC:
+ return true, e
+ default:
+ }
- // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
- // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
- // llAddr4) to discover multiple routers and prefixes, and auto-gen
- // multiple addresses.
+ return false, ndpAutoGenAddrEvent{}
+ }
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
- }
+ e1 := channel.New(0, 1280, linkAddr1)
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID1, err)
+ }
+ // We have other tests that make sure we receive the *correct* events
+ // on normal discovery of routers/prefixes, and auto-generated
+ // addresses. Here we just make sure we get an event and let other tests
+ // handle the correctness check.
+ expectAutoGenAddrEvent()
- e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
- }
+ e2 := channel.New(0, 1280, linkAddr2)
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID2, err)
+ }
+ expectAutoGenAddrEvent()
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
- }
+ // Receive RAs on NIC(1) and NIC(2) from default routers (llAddr3 and
+ // llAddr4) w/ PI (for prefix1 in RA from llAddr3 and prefix2 in RA from
+ // llAddr4) to discover multiple routers and prefixes, and auto-gen
+ // multiple addresses.
- e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
- if ok, _ := expectRouterEvent(); !ok {
- t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
- }
- if ok, _ := expectPrefixEvent(); !ok {
- t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
- }
- if ok, _ := expectAutoGenAddrEvent(); !ok {
- t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr1, nicID1)
+ }
- // We should have the auto-generated addresses added.
- nicinfo := s.NICInfo()
- nic1Addrs := nicinfo[nicID1].ProtocolAddresses
- nic2Addrs := nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if !containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ e1.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID1)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID1)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID1)
+ }
- // We can't proceed any further if we already failed the test (missing
- // some discovery/auto-generated address events or addresses).
- if t.Failed() {
- t.FailNow()
- }
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, lifetimeSeconds, prefix1, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr3, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix1, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e1Addr2, nicID2)
+ }
- s.SetForwarding(true)
+ e2.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr4, lifetimeSeconds, prefix2, true, true, lifetimeSeconds, lifetimeSeconds))
+ if ok, _ := expectRouterEvent(); !ok {
+ t.Errorf("expected router event for %s on NIC(%d)", llAddr4, nicID2)
+ }
+ if ok, _ := expectPrefixEvent(); !ok {
+ t.Errorf("expected prefix event for %s on NIC(%d)", prefix2, nicID2)
+ }
+ if ok, _ := expectAutoGenAddrEvent(); !ok {
+ t.Errorf("expected auto-gen addr event for %s on NIC(%d)", e2Addr2, nicID2)
+ }
- // Collect invalidation events after becoming a router
- gotRouterEvents := make(map[ndpRouterEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectRouterEvent()
- if !ok {
- t.Errorf("expected %d router events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotRouterEvents[e]++
- }
- gotPrefixEvents := make(map[ndpPrefixEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectPrefixEvent()
- if !ok {
- t.Errorf("expected %d prefix events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotPrefixEvents[e]++
- }
- gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
- for i := 0; i < maxEvents; i++ {
- ok, e := expectAutoGenAddrEvent()
- if !ok {
- t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", maxEvents, i)
- break
- }
- gotAutoGenAddrEvents[e]++
- }
+ // We should have the auto-generated addresses added.
+ nicinfo := s.NICInfo()
+ nic1Addrs := nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs := nicinfo[nicID2].ProtocolAddresses
+ if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if !containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
- // No need to proceed any further if we already failed the test (missing
- // some invalidation events).
- if t.Failed() {
- t.FailNow()
- }
+ // We can't proceed any further if we already failed the test (missing
+ // some discovery/auto-generated address events or addresses).
+ if t.Failed() {
+ t.FailNow()
+ }
- expectedRouterEvents := map[ndpRouterEvent]int{
- {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
- {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
- t.Errorf("router events mismatch (-want +got):\n%s", diff)
- }
- expectedPrefixEvents := map[ndpPrefixEvent]int{
- {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
- {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
- }
- if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
- t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
- }
- expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
- {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
- {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
- }
- if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
- t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
- }
+ test.cleanupFn(t, s)
- // Make sure the auto-generated addresses got removed.
- nicinfo = s.NICInfo()
- nic1Addrs = nicinfo[nicID1].ProtocolAddresses
- nic2Addrs = nicinfo[nicID2].ProtocolAddresses
- if !containsV6Addr(nic1Addrs, llAddrWithPrefix1) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
- }
- if containsV6Addr(nic1Addrs, e1Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
- }
- if !containsV6Addr(nic2Addrs, llAddrWithPrefix2) {
- t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr1) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
- }
- if containsV6Addr(nic2Addrs, e2Addr2) {
- t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
- }
+ // Collect invalidation events after having NDP state cleaned up.
+ gotRouterEvents := make(map[ndpRouterEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectRouterEvent()
+ if !ok {
+ t.Errorf("expected %d router events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotRouterEvents[e]++
+ }
+ gotPrefixEvents := make(map[ndpPrefixEvent]int)
+ for i := 0; i < maxRouterAndPrefixEvents; i++ {
+ ok, e := expectPrefixEvent()
+ if !ok {
+ t.Errorf("expected %d prefix events after becoming a router; got = %d", maxRouterAndPrefixEvents, i)
+ break
+ }
+ gotPrefixEvents[e]++
+ }
+ gotAutoGenAddrEvents := make(map[ndpAutoGenAddrEvent]int)
+ for i := 0; i < test.maxAutoGenAddrEvents; i++ {
+ ok, e := expectAutoGenAddrEvent()
+ if !ok {
+ t.Errorf("expected %d auto-generated address events after becoming a router; got = %d", test.maxAutoGenAddrEvents, i)
+ break
+ }
+ gotAutoGenAddrEvents[e]++
+ }
- // Should not get any more events (invalidation timers should have been
- // cancelled when we transitioned into a router).
- time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
- select {
- case <-ndpDisp.routerC:
- t.Error("unexpected router event")
- default:
- }
- select {
- case <-ndpDisp.prefixC:
- t.Error("unexpected prefix event")
- default:
- }
- select {
- case <-ndpDisp.autoGenAddrC:
- t.Error("unexpected auto-generated address event")
- default:
+ // No need to proceed any further if we already failed the test (missing
+ // some invalidation events).
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ expectedRouterEvents := map[ndpRouterEvent]int{
+ {nicID: nicID1, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID1, addr: llAddr4, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr3, discovered: false}: 1,
+ {nicID: nicID2, addr: llAddr4, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedRouterEvents, gotRouterEvents); diff != "" {
+ t.Errorf("router events mismatch (-want +got):\n%s", diff)
+ }
+ expectedPrefixEvents := map[ndpPrefixEvent]int{
+ {nicID: nicID1, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID1, prefix: subnet2, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet1, discovered: false}: 1,
+ {nicID: nicID2, prefix: subnet2, discovered: false}: 1,
+ }
+ if diff := cmp.Diff(expectedPrefixEvents, gotPrefixEvents); diff != "" {
+ t.Errorf("prefix events mismatch (-want +got):\n%s", diff)
+ }
+ expectedAutoGenAddrEvents := map[ndpAutoGenAddrEvent]int{
+ {nicID: nicID1, addr: e1Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID1, addr: e1Addr2, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr1, eventType: invalidatedAddr}: 1,
+ {nicID: nicID2, addr: e2Addr2, eventType: invalidatedAddr}: 1,
+ }
+
+ if !test.keepAutoGenLinkLocal {
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID1, addr: llAddrWithPrefix1, eventType: invalidatedAddr}] = 1
+ expectedAutoGenAddrEvents[ndpAutoGenAddrEvent{nicID: nicID2, addr: llAddrWithPrefix2, eventType: invalidatedAddr}] = 1
+ }
+
+ if diff := cmp.Diff(expectedAutoGenAddrEvents, gotAutoGenAddrEvents); diff != "" {
+ t.Errorf("auto-generated address events mismatch (-want +got):\n%s", diff)
+ }
+
+ // Make sure the auto-generated addresses got removed.
+ nicinfo = s.NICInfo()
+ nic1Addrs = nicinfo[nicID1].ProtocolAddresses
+ nic2Addrs = nicinfo[nicID2].ProtocolAddresses
+ if containsV6Addr(nic1Addrs, llAddrWithPrefix1) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix1, nicID1, nic1Addrs)
+ }
+ }
+ if containsV6Addr(nic1Addrs, e1Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr1, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic1Addrs, e1Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e1Addr2, nicID1, nic1Addrs)
+ }
+ if containsV6Addr(nic2Addrs, llAddrWithPrefix2) != test.keepAutoGenLinkLocal {
+ if test.keepAutoGenLinkLocal {
+ t.Errorf("missing %s from the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ } else {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", llAddrWithPrefix2, nicID2, nic2Addrs)
+ }
+ }
+ if containsV6Addr(nic2Addrs, e2Addr1) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr1, nicID2, nic2Addrs)
+ }
+ if containsV6Addr(nic2Addrs, e2Addr2) {
+ t.Errorf("still have %s in the list of addresses for NIC(%d): %+v", e2Addr2, nicID2, nic2Addrs)
+ }
+
+ // Should not get any more events (invalidation timers should have been
+ // cancelled when the NDP state was cleaned up).
+ time.Sleep(lifetimeSeconds*time.Second + defaultTimeout)
+ select {
+ case <-ndpDisp.routerC:
+ t.Error("unexpected router event")
+ default:
+ }
+ select {
+ case <-ndpDisp.prefixC:
+ t.Error("unexpected prefix event")
+ default:
+ }
+ select {
+ case <-ndpDisp.autoGenAddrC:
+ t.Error("unexpected auto-generated address event")
+ default:
+ }
+ })
}
}
@@ -3259,8 +3379,11 @@ func TestDHCPv6ConfigurationFromNDPDA(t *testing.T) {
func TestRouterSolicitation(t *testing.T) {
t.Parallel()
+ const nicID = 1
+
tests := []struct {
name string
+ linkHeaderLen uint16
maxRtrSolicit uint8
rtrSolicitInt time.Duration
effectiveRtrSolicitInt time.Duration
@@ -3277,6 +3400,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS with delay",
+ linkHeaderLen: 1,
maxRtrSolicit: 2,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3285,6 +3409,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Single RS without delay",
+ linkHeaderLen: 2,
maxRtrSolicit: 1,
rtrSolicitInt: time.Second,
effectiveRtrSolicitInt: time.Second,
@@ -3293,6 +3418,7 @@ func TestRouterSolicitation(t *testing.T) {
},
{
name: "Two RS without delay and invalid zero interval",
+ linkHeaderLen: 3,
maxRtrSolicit: 2,
rtrSolicitInt: 0,
effectiveRtrSolicitInt: 4 * time.Second,
@@ -3330,8 +3456,11 @@ func TestRouterSolicitation(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
- e := channel.New(int(test.maxRtrSolicit), 1280, linkAddr1)
- e.LinkEPCapabilities |= stack.CapabilityResolutionRequired
+ e := channelLinkWithHeaderLength{
+ Endpoint: channel.New(int(test.maxRtrSolicit), 1280, linkAddr1),
+ headerLength: test.linkHeaderLen,
+ }
+ e.Endpoint.LinkEPCapabilities |= stack.CapabilityResolutionRequired
waitForPkt := func(timeout time.Duration) {
t.Helper()
ctx, _ := context.WithTimeout(context.Background(), timeout)
@@ -3357,6 +3486,10 @@ func TestRouterSolicitation(t *testing.T) {
checker.TTL(header.NDPHopLimit),
checker.NDPRS(),
)
+
+ if l, want := p.Pkt.Header.AvailableLength(), int(test.linkHeaderLen); l != want {
+ t.Errorf("got p.Pkt.Header.AvailableLength() = %d; want = %d", l, want)
+ }
}
waitForNothing := func(timeout time.Duration) {
t.Helper()
@@ -3373,8 +3506,8 @@ func TestRouterSolicitation(t *testing.T) {
MaxRtrSolicitationDelay: test.maxRtrSolicitDelay,
},
})
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
// Make sure each RS got sent at the right
@@ -3406,77 +3539,130 @@ func TestRouterSolicitation(t *testing.T) {
})
}
-// TestStopStartSolicitingRouters tests that when forwarding is enabled or
-// disabled, router solicitations are stopped or started, respecitively.
func TestStopStartSolicitingRouters(t *testing.T) {
t.Parallel()
+ const nicID = 1
const interval = 500 * time.Millisecond
const delay = time.Second
const maxRtrSolicitations = 3
- e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
- waitForPkt := func(timeout time.Duration) {
- t.Helper()
- ctx, _ := context.WithTimeout(context.Background(), timeout)
- p, ok := e.ReadContext(ctx)
- if !ok {
- t.Fatal("timed out waiting for packet")
- return
- }
- if p.Proto != header.IPv6ProtocolNumber {
- t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
- }
- checker.IPv6(t, p.Pkt.Header.View(),
- checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
- checker.TTL(header.NDPHopLimit),
- checker.NDPRS())
- }
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
- NDPConfigs: stack.NDPConfigurations{
- MaxRtrSolicitations: maxRtrSolicitations,
- RtrSolicitationInterval: interval,
- MaxRtrSolicitationDelay: delay,
+ tests := []struct {
+ name string
+ startFn func(t *testing.T, s *stack.Stack)
+ stopFn func(t *testing.T, s *stack.Stack)
+ }{
+ // Tests that when forwarding is enabled or disabled, router solicitations
+ // are stopped or started, respectively.
+ {
+ name: "Forwarding enabled and disabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(false)
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
+ s.SetForwarding(true)
+ },
},
- })
- if err := s.CreateNIC(1, e); err != nil {
- t.Fatalf("CreateNIC(1) = %s", err)
- }
- // Enable forwarding which should stop router solicitations.
- s.SetForwarding(true)
- ctx, _ := context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- // A single RS may have been sent before forwarding was enabled.
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok = e.ReadContext(ctx); ok {
- t.Fatal("Should not have sent more than one RS message")
- }
- }
+ // Tests that when a NIC is enabled or disabled, router solicitations
+ // are started or stopped, respectively.
+ {
+ name: "NIC disabled and enabled",
+ startFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Enabling forwarding again should do nothing.
- s.SetForwarding(true)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
- }
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ },
+ stopFn: func(t *testing.T, s *stack.Stack) {
+ t.Helper()
- // Disable forwarding which should start router solicitations.
- s.SetForwarding(false)
- waitForPkt(delay + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- waitForPkt(interval + defaultAsyncEventTimeout)
- ctx, _ = context.WithTimeout(context.Background(), interval+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ },
+ },
}
- // Disabling forwarding again should do nothing.
- s.SetForwarding(false)
- ctx, _ = context.WithTimeout(context.Background(), delay+defaultTimeout)
- if _, ok := e.ReadContext(ctx); ok {
- t.Fatal("unexpectedly got a packet after becoming a router")
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ e := channel.New(maxRtrSolicitations, 1280, linkAddr1)
+ waitForPkt := func(timeout time.Duration) {
+ t.Helper()
+
+ ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ p, ok := e.ReadContext(ctx)
+ if !ok {
+ t.Fatal("timed out waiting for packet")
+ return
+ }
+
+ if p.Proto != header.IPv6ProtocolNumber {
+ t.Fatalf("got Proto = %d, want = %d", p.Proto, header.IPv6ProtocolNumber)
+ }
+ checker.IPv6(t, p.Pkt.Header.View(),
+ checker.SrcAddr(header.IPv6Any),
+ checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.TTL(header.NDPHopLimit),
+ checker.NDPRS())
+ }
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ NDPConfigs: stack.NDPConfigurations{
+ MaxRtrSolicitations: maxRtrSolicitations,
+ RtrSolicitationInterval: interval,
+ MaxRtrSolicitationDelay: delay,
+ },
+ })
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ // Stop soliciting routers.
+ test.stopFn(t, s)
+ ctx, cancel := context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ // A single RS may have been sent before forwarding was enabled.
+ ctx, cancel := context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok = e.ReadContext(ctx); ok {
+ t.Fatal("should not have sent more than one RS message")
+ }
+ }
+
+ // Stopping router solicitations after it has already been stopped should
+ // do nothing.
+ test.stopFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after router solicitation has been stopepd")
+ }
+
+ // Start soliciting routers.
+ test.startFn(t, s)
+ waitForPkt(delay + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ waitForPkt(interval + defaultAsyncEventTimeout)
+ ctx, cancel = context.WithTimeout(context.Background(), interval+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got an extra packet after sending out the expected RSs")
+ }
+
+ // Starting router solicitations after it has already completed should do
+ // nothing.
+ test.startFn(t, s)
+ ctx, cancel = context.WithTimeout(context.Background(), delay+defaultTimeout)
+ defer cancel()
+ if _, ok := e.ReadContext(ctx); ok {
+ t.Fatal("unexpectedly got a packet after finishing router solicitations")
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index ca3a7a07e..46d3a6646 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -27,6 +27,14 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
)
+var ipv4BroadcastAddr = tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 8 * header.IPv4AddressSize,
+ },
+}
+
// NIC represents a "network interface card" to which the networking stack is
// attached.
type NIC struct {
@@ -132,11 +140,77 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
nic.mu.packetEPs[netProto.Number()] = []PacketEndpoint{}
}
+ nic.linkEP.Attach(nic)
+
return nic
}
-// enable enables the NIC. enable will attach the link to its LinkEndpoint and
-// join the IPv6 All-Nodes Multicast address (ff02::1).
+// enabled returns true if n is enabled.
+func (n *NIC) enabled() bool {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ return enabled
+}
+
+// disable disables n.
+//
+// It undoes the work done by enable.
+func (n *NIC) disable() *tcpip.Error {
+ n.mu.RLock()
+ enabled := n.mu.enabled
+ n.mu.RUnlock()
+ if !enabled {
+ return nil
+ }
+
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ if !n.mu.enabled {
+ return nil
+ }
+
+ // TODO(b/147015577): Should Routes that are currently bound to n be
+ // invalidated? Currently, Routes will continue to work when a NIC is enabled
+ // again, and applications may not know that the underlying NIC was ever
+ // disabled.
+
+ if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok {
+ n.mu.ndp.stopSolicitingRouters()
+ n.mu.ndp.cleanupState(false /* hostOnly */)
+
+ // Stop DAD for all the unicast IPv6 endpoints that are in the
+ // permanentTentative state.
+ for _, r := range n.mu.endpoints {
+ if addr := r.ep.ID().LocalAddress; r.getKind() == permanentTentative && header.IsV6UnicastAddress(addr) {
+ n.mu.ndp.stopDuplicateAddressDetection(addr)
+ }
+ }
+
+ // The NIC may have already left the multicast group.
+ if err := n.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
+ // The address may have already been removed.
+ if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress {
+ return err
+ }
+ }
+
+ n.mu.enabled = false
+ return nil
+}
+
+// enable enables n.
+//
+// If the stack has IPv6 enabled, enable will join the IPv6 All-Nodes Multicast
+// address (ff02::1), start DAD for permanent addresses, and start soliciting
+// routers if the stack is not operating as a router. If the stack is also
+// configured to auto-generate a link-local address, one will be generated.
func (n *NIC) enable() *tcpip.Error {
n.mu.RLock()
enabled := n.mu.enabled
@@ -154,14 +228,9 @@ func (n *NIC) enable() *tcpip.Error {
n.mu.enabled = true
- n.attachLinkEndpoint()
-
// Create an endpoint to receive broadcast packets on this interface.
if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok {
- if _, err := n.addAddressLocked(tcpip.ProtocolAddress{
- Protocol: header.IPv4ProtocolNumber,
- AddressWithPrefix: tcpip.AddressWithPrefix{header.IPv4Broadcast, 8 * header.IPv4AddressSize},
- }, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
+ if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil {
return err
}
}
@@ -183,6 +252,14 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+ // Join the All-Nodes multicast group before starting DAD as responses to DAD
+ // (NDP NS) messages may be sent to the All-Nodes multicast group if the
+ // source address of the NDP NS is the unspecified address, as per RFC 4861
+ // section 7.2.4.
+ if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
+ return err
+ }
+
// Perform DAD on the all the unicast IPv6 endpoints that are in the permanent
// state.
//
@@ -200,10 +277,6 @@ func (n *NIC) enable() *tcpip.Error {
}
}
- if err := n.joinGroupLocked(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress); err != nil {
- return err
- }
-
// Do not auto-generate an IPv6 link-local address for loopback devices.
if n.stack.autoGenIPv6LinkLocal && !n.isLoopback() {
// The valid and preferred lifetime is infinite for the auto-generated
@@ -225,6 +298,33 @@ func (n *NIC) enable() *tcpip.Error {
return nil
}
+// remove detaches NIC from the link endpoint, and marks existing referenced
+// network endpoints expired. This guarantees no packets between this NIC and
+// the network stack.
+func (n *NIC) remove() *tcpip.Error {
+ n.mu.Lock()
+ defer n.mu.Unlock()
+
+ // Detach from link endpoint, so no packet comes in.
+ n.linkEP.Attach(nil)
+
+ // Remove permanent and permanentTentative addresses, so no packet goes out.
+ var errs []*tcpip.Error
+ for nid, ref := range n.mu.endpoints {
+ switch ref.getKind() {
+ case permanentTentative, permanent:
+ if err := n.removePermanentAddressLocked(nid.LocalAddress); err != nil {
+ errs = append(errs, err)
+ }
+ }
+ }
+ if len(errs) > 0 {
+ return errs[0]
+ }
+
+ return nil
+}
+
// becomeIPv6Router transitions n into an IPv6 router.
//
// When transitioning into an IPv6 router, host-only state (NDP discovered
@@ -234,7 +334,7 @@ func (n *NIC) becomeIPv6Router() {
n.mu.Lock()
defer n.mu.Unlock()
- n.mu.ndp.cleanupHostOnlyState()
+ n.mu.ndp.cleanupState(true /* hostOnly */)
n.mu.ndp.stopSolicitingRouters()
}
@@ -249,12 +349,6 @@ func (n *NIC) becomeIPv6Host() {
n.mu.ndp.startSolicitingRouters()
}
-// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
-// to start delivering packets.
-func (n *NIC) attachLinkEndpoint() {
- n.linkEP.Attach(n)
-}
-
// setPromiscuousMode enables or disables promiscuous mode.
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
@@ -712,6 +806,7 @@ func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
case permanentExpired, temporary:
continue
}
+
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ref.protocol,
AddressWithPrefix: tcpip.AddressWithPrefix{
@@ -1009,6 +1104,15 @@ func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error {
return nil
}
+// isInGroup returns true if n has joined the multicast group addr.
+func (n *NIC) isInGroup(addr tcpip.Address) bool {
+ n.mu.RLock()
+ joins := n.mu.mcastJoins[NetworkEndpointID{addr}]
+ n.mu.RUnlock()
+
+ return joins != 0
+}
+
func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, pkt tcpip.PacketBuffer) {
r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */)
r.RemoteLinkAddress = remotelinkAddr
@@ -1225,6 +1329,11 @@ func (n *NIC) Stack() *Stack {
return n.stack
}
+// LinkEndpoint returns the link endpoint of n.
+func (n *NIC) LinkEndpoint() LinkEndpoint {
+ return n.linkEP
+}
+
// isAddrTentative returns true if addr is tentative on n.
//
// Note that if addr is not associated with n, then this function will return
@@ -1411,7 +1520,7 @@ func (r *referencedNetworkEndpoint) isValidForOutgoing() bool {
//
// r's NIC must be read locked.
func (r *referencedNetworkEndpoint) isValidForOutgoingRLocked() bool {
- return r.getKind() != permanentExpired || r.nic.mu.spoofing
+ return r.nic.mu.enabled && (r.getKind() != permanentExpired || r.nic.mu.spoofing)
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index d83adf0ec..f9fd8f18f 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -74,10 +74,11 @@ type TransportEndpoint interface {
// HandleControlPacket takes ownership of pkt.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, pkt tcpip.PacketBuffer)
- // Close puts the endpoint in a closed state and frees all resources
- // associated with it. This cleanup may happen asynchronously. Wait can
- // be used to block on this asynchronous cleanup.
- Close()
+ // Abort initiates an expedited endpoint teardown. It puts the endpoint
+ // in a closed state and frees all resources associated with it. This
+ // cleanup may happen asynchronously. Wait can be used to block on this
+ // asynchronous cleanup.
+ Abort()
// Wait waits for any worker goroutines owned by the endpoint to stop.
//
@@ -160,6 +161,13 @@ type TransportProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -293,6 +301,13 @@ type NetworkProtocol interface {
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
+
+ // Close requests that any worker goroutines owned by the protocol
+ // stop.
+ Close()
+
+ // Wait waits for any worker goroutines owned by the protocol to stop.
+ Wait()
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6eac16e16..ebb6c5e3b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -881,6 +881,8 @@ type NICOptions struct {
// CreateNICWithOptions creates a NIC with the provided id, LinkEndpoint, and
// NICOptions. See the documentation on type NICOptions for details on how
// NICs can be configured.
+//
+// LinkEndpoint.Attach will be called to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOptions) *tcpip.Error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -900,7 +902,6 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
n := newNIC(s, id, opts.Name, ep, opts.Context)
-
s.nics[id] = n
if !opts.Disabled {
return n.enable()
@@ -910,34 +911,88 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
}
// CreateNIC creates a NIC with the provided id and LinkEndpoint and calls
-// `LinkEndpoint.Attach` to start delivering packets to it.
+// LinkEndpoint.Attach to bind ep with a NetworkDispatcher.
func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error {
return s.CreateNICWithOptions(id, ep, NICOptions{})
}
+// GetNICByName gets the NIC specified by name.
+func (s *Stack) GetNICByName(name string) (*NIC, bool) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ for _, nic := range s.nics {
+ if nic.Name() == name {
+ return nic, true
+ }
+ }
+ return nil, false
+}
+
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
- nic := s.nics[id]
- if nic == nil {
+ nic, ok := s.nics[id]
+ if !ok {
return tcpip.ErrUnknownNICID
}
return nic.enable()
}
+// DisableNIC disables the given NIC.
+func (s *Stack) DisableNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+
+ return nic.disable()
+}
+
// CheckNIC checks if a NIC is usable.
func (s *Stack) CheckNIC(id tcpip.NICID) bool {
s.mu.RLock()
+ defer s.mu.RUnlock()
+
nic, ok := s.nics[id]
- s.mu.RUnlock()
- if ok {
- return nic.linkEP.IsAttached()
+ if !ok {
+ return false
}
- return false
+
+ return nic.enabled()
+}
+
+// RemoveNIC removes NIC and all related routes from the network stack.
+func (s *Stack) RemoveNIC(id tcpip.NICID) *tcpip.Error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.ErrUnknownNICID
+ }
+ delete(s.nics, id)
+
+ // Remove routes in-place. n tracks the number of routes written.
+ n := 0
+ for i, r := range s.routeTable {
+ if r.NIC != id {
+ // Keep this route.
+ if i > n {
+ s.routeTable[n] = r
+ }
+ n++
+ }
+ }
+ s.routeTable = s.routeTable[:n]
+
+ return nic.remove()
}
// NICAddressRanges returns a map of NICIDs to their associated subnets.
@@ -989,7 +1044,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
- Running: nic.linkEP.IsAttached(),
+ Running: nic.enabled(),
Promiscuous: nic.isPromiscuousMode(),
Loopback: nic.isLoopback(),
}
@@ -1151,7 +1206,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
needRoute := !(isBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr))
if id != 0 && !needRoute {
- if nic, ok := s.nics[id]; ok {
+ if nic, ok := s.nics[id]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
return makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()), nil
}
@@ -1161,7 +1216,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) {
continue
}
- if nic, ok := s.nics[route.NIC]; ok {
+ if nic, ok := s.nics[route.NIC]; ok && nic.enabled() {
if ref := s.getRefEP(nic, localAddr, remoteAddr, netProto); ref != nil {
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
@@ -1391,7 +1446,13 @@ func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) {
// Endpoints created or modified during this call may not get closed.
func (s *Stack) Close() {
for _, e := range s.RegisteredEndpoints() {
- e.Close()
+ e.Abort()
+ }
+ for _, p := range s.transportProtocols {
+ p.proto.Close()
+ }
+ for _, p := range s.networkProtocols {
+ p.Close()
}
}
@@ -1409,6 +1470,12 @@ func (s *Stack) Wait() {
for _, e := range s.CleanupEndpoints() {
e.Wait()
}
+ for _, p := range s.transportProtocols {
+ p.proto.Wait()
+ }
+ for _, p := range s.networkProtocols {
+ p.Wait()
+ }
s.mu.RLock()
defer s.mu.RUnlock()
@@ -1614,6 +1681,18 @@ func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NIC
return tcpip.ErrUnknownNICID
}
+// IsInGroup returns true if the NIC with ID nicID has joined the multicast
+// group multicastAddr.
+func (s *Stack) IsInGroup(nicID tcpip.NICID, multicastAddr tcpip.Address) (bool, *tcpip.Error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ if nic, ok := s.nics[nicID]; ok {
+ return nic.isInGroup(multicastAddr), nil
+ }
+ return false, tcpip.ErrUnknownNICID
+}
+
// IPTables returns the stack's iptables.
func (s *Stack) IPTables() iptables.IPTables {
s.tablesMu.RLock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 7ba604442..e15db40fb 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -33,6 +33,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -234,10 +235,33 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements TransportProtocol.Close.
+func (*fakeNetworkProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeNetworkProtocol) Wait() {}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
+// linkEPWithMockedAttach is a stack.LinkEndpoint that tests can use to verify
+// that LinkEndpoint.Attach was called.
+type linkEPWithMockedAttach struct {
+ stack.LinkEndpoint
+ attached bool
+}
+
+// Attach implements stack.LinkEndpoint.Attach.
+func (l *linkEPWithMockedAttach) Attach(d stack.NetworkDispatcher) {
+ l.LinkEndpoint.Attach(d)
+ l.attached = true
+}
+
+func (l *linkEPWithMockedAttach) isAttached() bool {
+ return l.attached
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
@@ -509,6 +533,296 @@ func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr
}
}
+// TestAttachToLinkEndpointImmediately tests that a LinkEndpoint is attached to
+// a NetworkDispatcher when the NIC is created.
+func TestAttachToLinkEndpointImmediately(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ nicOpts stack.NICOptions
+ }{
+ {
+ name: "Create enabled NIC",
+ nicOpts: stack.NICOptions{Disabled: false},
+ },
+ {
+ name: "Create disabled NIC",
+ nicOpts: stack.NICOptions{Disabled: true},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: loopback.New(),
+ }
+
+ if err := s.CreateNICWithOptions(nicID, &e, test.nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, test.nicOpts, err)
+ }
+ if !e.isAttached() {
+ t.Fatalf("link endpoint not attached to a network disatcher")
+ }
+ })
+ }
+}
+
+func TestDisableUnknownNIC(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ if err := s.DisableNIC(1); err != tcpip.ErrUnknownNICID {
+ t.Fatalf("got s.DisableNIC(1) = %v, want = %s", err, tcpip.ErrUnknownNICID)
+ }
+}
+
+func TestDisabledNICsNICInfoAndCheckNIC(t *testing.T) {
+ const nicID = 1
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ e := loopback.New()
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ checkNIC := func(enabled bool) {
+ t.Helper()
+
+ allNICInfo := s.NICInfo()
+ nicInfo, ok := allNICInfo[nicID]
+ if !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ } else if nicInfo.Flags.Running != enabled {
+ t.Errorf("got nicInfo.Flags.Running = %t, want = %t", nicInfo.Flags.Running, enabled)
+ }
+
+ if got := s.CheckNIC(nicID); got != enabled {
+ t.Errorf("got s.CheckNIC(%d) = %t, want = %t", nicID, got, enabled)
+ }
+ }
+
+ // NIC should initially report itself as disabled.
+ checkNIC(false)
+
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(true)
+
+ // If the NIC is not reporting a correct enabled status, we cannot trust the
+ // next check so end the test here.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ checkNIC(false)
+}
+
+func TestRoutesWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ // Test routes to odd address.
+ testRoute(t, s, unspecifiedNIC, "", "\x05", addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, "\x05", addr1)
+ testRoute(t, s, nicID1, addr1, "\x05", addr1)
+
+ // Test routes to even address.
+ testRoute(t, s, unspecifiedNIC, "", "\x06", addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, "\x06", addr2)
+ testRoute(t, s, nicID2, addr2, "\x06", addr2)
+
+ // Disabling NIC1 should result in no routes to odd addresses. Routes to even
+ // addresses should continue to be available as NIC2 is still enabled.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ nic1Dst := tcpip.Address("\x05")
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ nic2Dst := tcpip.Address("\x06")
+ testRoute(t, s, unspecifiedNIC, "", nic2Dst, addr2)
+ testRoute(t, s, unspecifiedNIC, addr2, nic2Dst, addr2)
+ testRoute(t, s, nicID2, addr2, nic2Dst, addr2)
+
+ // Disabling NIC2 should result in no routes to even addresses. No route
+ // should be available to any address as routes to odd addresses were made
+ // unavailable by disabling NIC1 above.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testNoRoute(t, s, unspecifiedNIC, "", nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr1, nic1Dst)
+ testNoRoute(t, s, nicID1, addr1, nic1Dst)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+
+ // Enabling NIC1 should make routes to odd addresses available again. Routes
+ // to even addresses should continue to be unavailable as NIC2 is still
+ // disabled.
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testRoute(t, s, unspecifiedNIC, "", nic1Dst, addr1)
+ testRoute(t, s, unspecifiedNIC, addr1, nic1Dst, addr1)
+ testRoute(t, s, nicID1, addr1, nic1Dst, addr1)
+ testNoRoute(t, s, unspecifiedNIC, "", nic2Dst)
+ testNoRoute(t, s, unspecifiedNIC, addr2, nic2Dst)
+ testNoRoute(t, s, nicID2, addr2, nic2Dst)
+}
+
+func TestRouteWritePacketWithDisabledNIC(t *testing.T) {
+ const unspecifiedNIC = 0
+ const nicID1 = 1
+ const nicID2 = 2
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
+
+ ep1 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, ep1); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ addr1 := tcpip.Address("\x01")
+ if err := s.AddAddress(nicID1, fakeNetNumber, addr1); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, fakeNetNumber, addr1, err)
+ }
+
+ ep2 := channel.New(1, defaultMTU, "")
+ if err := s.CreateNIC(nicID2, ep2); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ addr2 := tcpip.Address("\x02")
+ if err := s.AddAddress(nicID2, fakeNetNumber, addr2); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, fakeNetNumber, addr2, err)
+ }
+
+ // Set a route table that sends all packets with odd destination
+ // addresses through the first NIC, and all even destination address
+ // through the second one.
+ {
+ subnet0, err := tcpip.NewSubnet("\x00", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ subnet1, err := tcpip.NewSubnet("\x01", "\x01")
+ if err != nil {
+ t.Fatal(err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {Destination: subnet1, Gateway: "\x00", NIC: nicID1},
+ {Destination: subnet0, Gateway: "\x00", NIC: nicID2},
+ })
+ }
+
+ nic1Dst := tcpip.Address("\x05")
+ r1, err := s.FindRoute(nicID1, addr1, nic1Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID1, addr1, nic1Dst, fakeNetNumber, err)
+ }
+ defer r1.Release()
+
+ nic2Dst := tcpip.Address("\x06")
+ r2, err := s.FindRoute(nicID2, addr2, nic2Dst, fakeNetNumber, false /* multicastLoop */)
+ if err != nil {
+ t.Errorf("FindRoute(%d, %s, %s, %d, false): %s", nicID2, addr2, nic2Dst, fakeNetNumber, err)
+ }
+ defer r2.Release()
+
+ // If we failed to get routes r1 or r2, we cannot proceed with the test.
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ buf := buffer.View([]byte{1})
+ testSend(t, r1, ep1, buf)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC1 should fail.
+ if err := s.DisableNIC(nicID1); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID1, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testSend(t, r2, ep2, buf)
+
+ // Writes with Routes that use the disabled NIC2 should fail.
+ if err := s.DisableNIC(nicID2); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID2, err)
+ }
+ testFailingSend(t, r1, ep1, buf, tcpip.ErrInvalidEndpointState)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+
+ // Writes with Routes that use the re-enabled NIC1 should succeed.
+ // TODO(b/147015577): Should we instead completely invalidate all Routes that
+ // were bound to a disabled NIC at some point?
+ if err := s.EnableNIC(nicID1); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID1, err)
+ }
+ testSend(t, r1, ep1, buf)
+ testFailingSend(t, r2, ep2, buf, tcpip.ErrInvalidEndpointState)
+}
+
func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
@@ -2173,13 +2487,29 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
e := channel.New(0, 1280, test.linkAddr)
s := stack.New(opts)
- nicOpts := stack.NICOptions{Name: test.nicName}
+ nicOpts := stack.NICOptions{Name: test.nicName, Disabled: true}
if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, opts, err)
}
- var expectedMainAddr tcpip.AddressWithPrefix
+ // A new disabled NIC should not have any address, even if auto generation
+ // was enabled.
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should attempt auto-generation of a link-local
+ // address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ var expectedMainAddr tcpip.AddressWithPrefix
if test.shouldGen {
expectedMainAddr = tcpip.AddressWithPrefix{
Address: test.expectedAddr,
@@ -2460,13 +2790,14 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
const (
- linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
- globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
- nicID = 1
+ linkLocalAddr1 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ linkLocalAddr2 = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ linkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr1 = tcpip.Address("\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ uniqueLocalAddr2 = tcpip.Address("\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ globalAddr1 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01")
+ globalAddr2 = tcpip.Address("\xa0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02")
+ nicID = 1
)
// Rule 3 is not tested here, and is instead tested by NDP's AutoGenAddr test.
@@ -2540,6 +2871,18 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
expectedLocalAddr: linkLocalAddr1,
},
{
+ name: "Link Local most preferred for link local multicast (last address)",
+ nicAddrs: []tcpip.Address{globalAddr1, uniqueLocalAddr1, linkLocalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
+ name: "Link Local most preferred for link local multicast (first address)",
+ nicAddrs: []tcpip.Address{linkLocalAddr1, uniqueLocalAddr1, globalAddr1},
+ connectAddr: linkLocalMulticastAddr,
+ expectedLocalAddr: linkLocalAddr1,
+ },
+ {
name: "Unique Local most preferred (last address)",
nicAddrs: []tcpip.Address{uniqueLocalAddr1, globalAddr1, linkLocalAddr1},
connectAddr: uniqueLocalAddr2,
@@ -2609,6 +2952,111 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) {
}
}
+func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ allStackAddrs := s.AllAddresses()
+ allNICAddrs, ok := allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+
+ // Enabling the NIC should add the IPv4 broadcast address.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 1 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 1", l)
+ }
+ want := tcpip.ProtocolAddress{
+ Protocol: header.IPv4ProtocolNumber,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: header.IPv4Broadcast,
+ PrefixLen: 32,
+ },
+ }
+ if allNICAddrs[0] != want {
+ t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want)
+ }
+
+ // Disabling the NIC should remove the IPv4 broadcast address.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ allStackAddrs = s.AllAddresses()
+ allNICAddrs, ok = allStackAddrs[nicID]
+ if !ok {
+ t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs)
+ }
+ if l := len(allNICAddrs); l != 0 {
+ t.Fatalf("got len(allNICAddrs) = %d, want = 0", l)
+ }
+}
+
+func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) {
+ const nicID = 1
+
+ e := loopback.New()
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ })
+ nicOpts := stack.NICOptions{Disabled: true}
+ if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil {
+ t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err)
+ }
+
+ // Should not be in the IPv6 all-nodes multicast group yet because the NIC has
+ // not been enabled yet.
+ isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be joined when the NIC is enabled.
+ if err := s.EnableNIC(nicID); err != nil {
+ t.Fatalf("s.EnableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+
+ // The all-nodes multicast group should be left when the NIC is disabled.
+ if err := s.DisableNIC(nicID); err != nil {
+ t.Fatalf("s.DisableNIC(%d): %s", nicID, err)
+ }
+ isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress)
+ if err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err)
+ }
+ if isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress)
+ }
+}
+
// TestDoDADWhenNICEnabled tests that IPv6 endpoints that were added while a NIC
// was disabled have DAD performed on them when the NIC is enabled.
func TestDoDADWhenNICEnabled(t *testing.T) {
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index d686e6eb8..778c0a4d6 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -306,26 +306,6 @@ func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, p
ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// Close implements stack.TransportEndpoint.Close.
-func (ep *multiPortEndpoint) Close() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Close()
- }
-}
-
-// Wait implements stack.TransportEndpoint.Wait.
-func (ep *multiPortEndpoint) Wait() {
- ep.mu.RLock()
- eps := append([]TransportEndpoint(nil), ep.endpointsArr...)
- ep.mu.RUnlock()
- for _, e := range eps {
- e.Wait()
- }
-}
-
// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
// list. The list might be empty already.
func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 869c69a6d..5d1da2f8b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -61,6 +61,10 @@ func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netP
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
+func (f *fakeTransportEndpoint) Abort() {
+ f.Close()
+}
+
func (f *fakeTransportEndpoint) Close() {
f.route.Release()
}
@@ -272,7 +276,7 @@ func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.N
return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
-func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (*fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return nil, tcpip.ErrUnknownProtocol
}
@@ -310,6 +314,15 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+// Abort implements TransportProtocol.Abort.
+func (*fakeTransportProtocol) Abort() {}
+
+// Close implements tcpip.Endpoint.Close.
+func (*fakeTransportProtocol) Close() {}
+
+// Wait implements TransportProtocol.Wait.
+func (*fakeTransportProtocol) Wait() {}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index ce5527391..3dc5d87d6 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -341,9 +341,15 @@ type ControlMessages struct {
// networking stack.
type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
- // associated with it.
+ // associated with it. Close initiates the teardown process, the
+ // Endpoint may not be fully closed when Close returns.
Close()
+ // Abort initiates an expedited endpoint teardown. As compared to
+ // Close, Abort prioritizes closing the Endpoint quickly over cleanly.
+ // Abort is best effort; implementing Abort with Close is acceptable.
+ Abort()
+
// Read reads data from the endpoint and optionally returns the sender.
//
// This method does not block if there is no data pending. It will also
diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go
index 48764b978..2f98a996f 100644
--- a/pkg/tcpip/time_unsafe.go
+++ b/pkg/tcpip/time_unsafe.go
@@ -25,6 +25,8 @@ import (
)
// StdClock implements Clock with the time package.
+//
+// +stateify savable
type StdClock struct{}
var _ Clock = (*StdClock)(nil)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 42afb3f5b..426da1ee6 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -96,6 +96,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 9ce500e80..113d92901 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -104,20 +104,26 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
-func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
+func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool {
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol4 returns an ICMPv4 transport protocol.
func NewProtocol4() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index fc5bc69fa..5722815e9 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -98,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
return ep, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() {
ep.mu.Lock()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ee9c4c58b..2ef5fac76 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -121,6 +121,11 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
return e, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (e *endpoint) Close() {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 08afb7c17..13e383ffc 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -299,6 +299,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
if err := h.execute(); err != nil {
ep.Close()
+ // Wake up any waiters. This is strictly not required normally
+ // as a socket that was never accepted can't really have any
+ // registered waiters except when stack.Wait() is called which
+ // waits for all registered endpoints to stop and expects an
+ // EventHUp.
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+
if l.listenEP != nil {
l.removePendingEndpoint(ep)
}
@@ -607,7 +614,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
- e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
}()
s := sleep.Sleeper{}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5c5397823..7730e6445 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -1372,7 +1372,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.snd.updateMaxPayloadSize(mtu, count)
}
- if n&notifyReset != 0 {
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
return tcpip.ErrConnectionAborted
}
@@ -1655,7 +1655,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
}
case notification:
n := e.fetchNotifications()
- if n&notifyClose != 0 {
+ if n&notifyClose != 0 || n&notifyAbort != 0 {
return nil
}
if n&notifyDrain != 0 {
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
index e18012ac0..d792b07d6 100644
--- a/pkg/tcpip/transport/tcp/dispatcher.go
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -68,17 +68,28 @@ func (q *epQueue) empty() bool {
type processor struct {
epQ epQueue
newEndpointWaker sleep.Waker
+ closeWaker sleep.Waker
id int
+ wg sync.WaitGroup
}
func newProcessor(id int) *processor {
p := &processor{
id: id,
}
+ p.wg.Add(1)
go p.handleSegments()
return p
}
+func (p *processor) close() {
+ p.closeWaker.Assert()
+}
+
+func (p *processor) wait() {
+ p.wg.Wait()
+}
+
func (p *processor) queueEndpoint(ep *endpoint) {
// Queue an endpoint for processing by the processor goroutine.
p.epQ.enqueue(ep)
@@ -87,11 +98,17 @@ func (p *processor) queueEndpoint(ep *endpoint) {
func (p *processor) handleSegments() {
const newEndpointWaker = 1
+ const closeWaker = 2
s := sleep.Sleeper{}
s.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ s.AddWaker(&p.closeWaker, closeWaker)
defer s.Done()
for {
- s.Fetch(true)
+ id, ok := s.Fetch(true)
+ if ok && id == closeWaker {
+ p.wg.Done()
+ return
+ }
for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() {
if ep.segmentQueue.empty() {
continue
@@ -160,6 +177,18 @@ func newDispatcher(nProcessors int) *dispatcher {
}
}
+func (d *dispatcher) close() {
+ for _, p := range d.processors {
+ p.close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ for _, p := range d.processors {
+ p.wait()
+ }
+}
+
func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt tcpip.PacketBuffer) {
ep := stackEP.(*endpoint)
s := newSegment(r, id, pkt)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index f2be0e651..f1ad19dac 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -121,6 +121,8 @@ const (
notifyDrain
notifyReset
notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
notifyKeepaliveChanged
notifyMSSChanged
// notifyTickleWorker is used to tickle the protocol main loop during a
@@ -785,6 +787,24 @@ func (e *endpoint) notifyProtocolGoroutine(n uint32) {
}
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ // The abort notification is not processed synchronously, so no
+ // synchronization is needed.
+ //
+ // If the endpoint becomes connected after this check, we still close
+ // the endpoint. This worst case results in a slower abort.
+ //
+ // If the endpoint disconnected after the check, nothing needs to be
+ // done, so sending a notification which will potentially be ignored is
+ // fine.
+ if e.EndpointState().connected() {
+ e.notifyProtocolGoroutine(notifyAbort)
+ return
+ }
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources associated
// with it. It must be called only once and with no other concurrent calls to
// the endpoint.
@@ -829,9 +849,18 @@ func (e *endpoint) closeNoShutdown() {
// Either perform the local cleanup or kick the worker to make sure it
// knows it needs to cleanup.
tcpip.AddDanglingEndpoint(e)
- if !e.workerRunning {
+ switch e.EndpointState() {
+ // Sockets in StateSynRecv state(passive connections) are closed when
+ // the handshake fails or if the listening socket is closed while
+ // handshake was in progress. In such cases the handshake goroutine
+ // is already gone by the time Close is called and we need to cleanup
+ // here.
+ case StateInitial, StateBound, StateSynRecv:
e.cleanupLocked()
- } else {
+ e.setEndpointState(StateClose)
+ case StateError, StateClose:
+ // do nothing.
+ default:
e.workerCleanup = true
e.notifyProtocolGoroutine(notifyClose)
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 958c06fa7..73098d904 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -194,7 +194,7 @@ func replyWithReset(s *segment) {
sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, flags, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */)
}
-// SetOption implements TransportProtocol.SetOption.
+// SetOption implements stack.TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case SACKEnabled:
@@ -269,7 +269,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
}
}
-// Option implements TransportProtocol.Option.
+// Option implements stack.TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
@@ -331,6 +331,16 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
+// Close implements stack.TransportProtocol.Close.
+func (p *protocol) Close() {
+ p.dispatcher.close()
+}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (p *protocol) Wait() {
+ p.dispatcher.wait()
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cc118c993..5b2b16afa 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -543,8 +543,9 @@ func TestCurrentConnectedIncrement(t *testing.T) {
),
)
- // Wait for the TIME-WAIT state to transition to CLOSED.
- time.Sleep(1 * time.Second)
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %v, want = 0", got)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index eff7f3600..1c6a600b8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -186,6 +186,11 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (e *endpoint) Abort() {
+ e.Close()
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 259c3072a..8df089d22 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -180,16 +180,22 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans
return true
}
-// SetOption implements TransportProtocol.SetOption.
-func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+// SetOption implements stack.TransportProtocol.SetOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// Option implements TransportProtocol.Option.
-func (p *protocol) Option(option interface{}) *tcpip.Error {
+// Option implements stack.TransportProtocol.Option.
+func (*protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// Close implements stack.TransportProtocol.Close.
+func (*protocol) Close() {}
+
+// Wait implements stack.TransportProtocol.Wait.
+func (*protocol) Wait() {}
+
// NewProtocol returns a UDP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{}
diff --git a/pkg/usermem/BUILD b/pkg/usermem/BUILD
index ff8b9e91a..6c9ada9c7 100644
--- a/pkg/usermem/BUILD
+++ b/pkg/usermem/BUILD
@@ -25,7 +25,6 @@ go_library(
"bytes_io_unsafe.go",
"usermem.go",
"usermem_arm64.go",
- "usermem_unsafe.go",
"usermem_x86.go",
],
visibility = ["//:sandbox"],
@@ -33,6 +32,7 @@ go_library(
"//pkg/atomicbitops",
"//pkg/binary",
"//pkg/context",
+ "//pkg/gohacks",
"//pkg/log",
"//pkg/safemem",
"//pkg/syserror",
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 71fd4e155..d2f4403b0 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/gohacks"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/syserror"
)
@@ -251,7 +252,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
}
end, ok := addr.AddLength(uint64(readlen))
if !ok {
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
+ return gohacks.StringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
@@ -272,16 +273,16 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
// Look for the terminating zero byte, which may have occurred before
// hitting err.
if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
+ return gohacks.StringFromImmutableBytes(buf[:done+i]), nil
}
done += n
if err != nil {
- return stringFromImmutableBytes(buf[:done]), err
+ return gohacks.StringFromImmutableBytes(buf[:done]), err
}
addr = end
}
- return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
+ return gohacks.StringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
// CopyOutVec copies bytes from src to the memory mapped at ars in uio. The
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index ae4dd102a..26f68fe3d 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -19,7 +19,6 @@ go_library(
"loader_amd64.go",
"loader_arm64.go",
"network.go",
- "pprof.go",
"strace.go",
"user.go",
],
@@ -91,6 +90,7 @@ go_library(
"//pkg/usermem",
"//runsc/boot/filter",
"//runsc/boot/platforms",
+ "//runsc/boot/pprof",
"//runsc/specutils",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 9c9e94864..17e774e0c 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -32,6 +32,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -142,7 +143,7 @@ func newController(fd int, l *Loader) (*controller, error) {
}
srv.Register(manager)
- if eps, ok := l.k.NetworkStack().(*netstack.Stack); ok {
+ if eps, ok := l.k.RootNetworkNamespace().Stack().(*netstack.Stack); ok {
net := &Network{
Stack: eps.Stack,
}
@@ -341,7 +342,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
return fmt.Errorf("creating memory file: %v", err)
}
k.SetMemoryFile(mf)
- networkStack := cm.l.k.NetworkStack()
+ networkStack := cm.l.k.RootNetworkNamespace().Stack()
cm.l.k = k
// Set up the restore environment.
@@ -365,9 +366,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
}
if cm.l.conf.ProfileEnable {
- // initializePProf opens /proc/self/maps, so has to be
- // called before installing seccomp filters.
- initializePProf()
+ // pprof.Initialize opens /proc/self/maps, so has to be called before
+ // installing seccomp filters.
+ pprof.Initialize()
}
// Seccomp filters have to be applied before parsing the state file.
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index c69f4c602..a4627905e 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -229,7 +229,9 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_NANOSLEEP: {},
syscall.SYS_PPOLL: {},
syscall.SYS_PREAD64: {},
+ syscall.SYS_PREADV: {},
syscall.SYS_PWRITE64: {},
+ syscall.SYS_PWRITEV: {},
syscall.SYS_READ: {},
syscall.SYS_RECVMSG: []seccomp.Rule{
{
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index eef43b9df..e7ca98134 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -49,6 +49,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/watchdog"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -60,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/runsc/boot/filter"
_ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms.
+ "gvisor.dev/gvisor/runsc/boot/pprof"
"gvisor.dev/gvisor/runsc/specutils"
// Include supported socket providers.
@@ -230,11 +232,8 @@ func New(args Args) (*Loader, error) {
return nil, fmt.Errorf("enabling strace: %v", err)
}
- // Create an empty network stack because the network namespace may be empty at
- // this point. Netns is configured before Run() is called. Netstack is
- // configured using a control uRPC message. Host network is configured inside
- // Run().
- networkStack, err := newEmptyNetworkStack(args.Conf, k, k)
+ // Create root network namespace/stack.
+ netns, err := newRootNetworkNamespace(args.Conf, k, k)
if err != nil {
return nil, fmt.Errorf("creating network: %v", err)
}
@@ -277,7 +276,7 @@ func New(args Args) (*Loader, error) {
FeatureSet: cpuid.HostFeatureSet(),
Timekeeper: tk,
RootUserNamespace: creds.UserNamespace,
- NetworkStack: networkStack,
+ RootNetworkNamespace: netns,
ApplicationCores: uint(args.NumCPU),
Vdso: vdso,
RootUTSNamespace: kernel.NewUTSNamespace(args.Spec.Hostname, args.Spec.Hostname, creds.UserNamespace),
@@ -466,7 +465,7 @@ func (l *Loader) run() error {
// Delay host network configuration to this point because network namespace
// is configured after the loader is created and before Run() is called.
log.Debugf("Configuring host network")
- stack := l.k.NetworkStack().(*hostinet.Stack)
+ stack := l.k.RootNetworkNamespace().Stack().(*hostinet.Stack)
if err := stack.Configure(); err != nil {
return err
}
@@ -485,7 +484,7 @@ func (l *Loader) run() error {
// l.restore is set by the container manager when a restore call is made.
if !l.restore {
if l.conf.ProfileEnable {
- initializePProf()
+ pprof.Initialize()
}
// Finally done with all configuration. Setup filters before user code
@@ -908,48 +907,92 @@ func (l *Loader) WaitExit() kernel.ExitStatus {
return l.k.GlobalInit().ExitStatus()
}
-func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) {
+ // Create an empty network stack because the network namespace may be empty at
+ // this point. Netns is configured before Run() is called. Netstack is
+ // configured using a control uRPC message. Host network is configured inside
+ // Run().
switch conf.Network {
case NetworkHost:
- return hostinet.NewStack(), nil
+ // No network namespacing support for hostinet yet, hence creator is nil.
+ return inet.NewRootNamespace(hostinet.NewStack(), nil), nil
case NetworkNone, NetworkSandbox:
- // NetworkNone sets up loopback using netstack.
- netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
- transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
- s := netstack.Stack{stack.New(stack.Options{
- NetworkProtocols: netProtos,
- TransportProtocols: transProtos,
- Clock: clock,
- Stats: netstack.Metrics,
- HandleLocal: true,
- // Enable raw sockets for users with sufficient
- // privileges.
- RawFactory: raw.EndpointFactory{},
- UniqueID: uniqueID,
- })}
-
- // Enable SACK Recovery.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
- return nil, fmt.Errorf("failed to enable SACK: %v", err)
+ s, err := newEmptySandboxNetworkStack(clock, uniqueID)
+ if err != nil {
+ return nil, err
}
+ creator := &sandboxNetstackCreator{
+ clock: clock,
+ uniqueID: uniqueID,
+ }
+ return inet.NewRootNamespace(s, creator), nil
- // Set default TTLs as required by socket/netstack.
- s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
- s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ default:
+ panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ }
- // Enable Receive Buffer Auto-Tuning.
- if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
- return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
- }
+}
- s.FillDefaultIPTables()
+func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := netstack.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: netstack.Metrics,
+ HandleLocal: true,
+ // Enable raw sockets for users with sufficient
+ // privileges.
+ RawFactory: raw.EndpointFactory{},
+ UniqueID: uniqueID,
+ })}
- return &s, nil
+ // Enable SACK Recovery.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
+ return nil, fmt.Errorf("failed to enable SACK: %v", err)
+ }
- default:
- panic(fmt.Sprintf("invalid network configuration: %v", conf.Network))
+ // Set default TTLs as required by socket/netstack.
+ s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+ s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL))
+
+ // Enable Receive Buffer Auto-Tuning.
+ if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+ return nil, fmt.Errorf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ s.FillDefaultIPTables()
+
+ return &s, nil
+}
+
+// sandboxNetstackCreator implements kernel.NetworkStackCreator.
+//
+// +stateify savable
+type sandboxNetstackCreator struct {
+ clock tcpip.Clock
+ uniqueID stack.UniqueID
+}
+
+// CreateStack implements kernel.NetworkStackCreator.CreateStack.
+func (f *sandboxNetstackCreator) CreateStack() (inet.Stack, error) {
+ s, err := newEmptySandboxNetworkStack(f.clock, f.uniqueID)
+ if err != nil {
+ return nil, err
}
+
+ // Setup loopback.
+ n := &Network{Stack: s.(*netstack.Stack).Stack}
+ nicID := tcpip.NICID(f.uniqueID.UniqueID())
+ link := DefaultLoopbackLink
+ linkEP := loopback.New()
+ if err := n.createNICWithAddrs(nicID, link.Name, linkEP, link.Addresses); err != nil {
+ return nil, err
+ }
+
+ return s, nil
}
// signal sends a signal to one or more processes in a container. If PID is 0,
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 6a8765ec8..bee6ee336 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -17,6 +17,7 @@ package boot
import (
"fmt"
"net"
+ "strings"
"syscall"
"gvisor.dev/gvisor/pkg/log"
@@ -31,6 +32,32 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
+var (
+ // DefaultLoopbackLink contains IP addresses and routes of "127.0.0.1/8" and
+ // "::1/8" on "lo" interface.
+ DefaultLoopbackLink = LoopbackLink{
+ Name: "lo",
+ Addresses: []net.IP{
+ net.IP("\x7f\x00\x00\x01"),
+ net.IPv6loopback,
+ },
+ Routes: []Route{
+ {
+ Destination: net.IPNet{
+ IP: net.IPv4(0x7f, 0, 0, 0),
+ Mask: net.IPv4Mask(0xff, 0, 0, 0),
+ },
+ },
+ {
+ Destination: net.IPNet{
+ IP: net.IPv6loopback,
+ Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
+ },
+ },
+ },
+ }
+)
+
// Network exposes methods that can be used to configure a network stack.
type Network struct {
Stack *stack.Stack
diff --git a/runsc/boot/pprof/BUILD b/runsc/boot/pprof/BUILD
new file mode 100644
index 000000000..29cb42b2f
--- /dev/null
+++ b/runsc/boot/pprof/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pprof",
+ srcs = ["pprof.go"],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+)
diff --git a/runsc/boot/pprof.go b/runsc/boot/pprof/pprof.go
index 463362f02..1ded20dee 100644
--- a/runsc/boot/pprof.go
+++ b/runsc/boot/pprof/pprof.go
@@ -12,7 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package boot
+// Package pprof provides a stub to initialize custom profilers.
+package pprof
-func initializePProf() {
+// Initialize will be called at boot for initializing custom profilers.
+func Initialize() {
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 04a7dc237..bdd65b498 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -71,6 +71,7 @@ func waitForProcessCount(cont *Container, want int) error {
return &backoff.PermanentError{Err: err}
}
if got := len(pss); got != want {
+ log.Infof("Waiting for process count to reach %d. Current: %d", want, got)
return fmt.Errorf("wrong process count, got: %d, want: %d", got, want)
}
return nil
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index 99e143696..bc093fba5 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -21,7 +21,6 @@ import (
"path/filepath"
"runtime"
"strconv"
- "strings"
"syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -75,30 +74,8 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Confi
}
func createDefaultLoopbackInterface(conn *urpc.Client) error {
- link := boot.LoopbackLink{
- Name: "lo",
- Addresses: []net.IP{
- net.IP("\x7f\x00\x00\x01"),
- net.IPv6loopback,
- },
- Routes: []boot.Route{
- {
- Destination: net.IPNet{
-
- IP: net.IPv4(0x7f, 0, 0, 0),
- Mask: net.IPv4Mask(0xff, 0, 0, 0),
- },
- },
- {
- Destination: net.IPNet{
- IP: net.IPv6loopback,
- Mask: net.IPMask(strings.Repeat("\xff", net.IPv6len)),
- },
- },
- },
- }
if err := conn.Call(boot.NetworkCreateLinksAndRoutes, &boot.CreateLinksAndRoutesArgs{
- LoopbackLinks: []boot.LoopbackLink{link},
+ LoopbackLinks: []boot.LoopbackLink{boot.DefaultLoopbackLink},
}, nil); err != nil {
return fmt.Errorf("creating loopback link and routes: %v", err)
}
diff --git a/scripts/common_build.sh b/scripts/common_build.sh
index ae8b67383..3be0bb21c 100755
--- a/scripts/common_build.sh
+++ b/scripts/common_build.sh
@@ -70,7 +70,9 @@ function collect_logs() {
for d in `find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs dirname | sort | uniq`; do
junitparser merge `find $d -name test.xml` $d/test.xml
cat $d/shard_*_of_*/test.log > $d/test.log
- ls -l $d/shard_*_of_*/test.outputs/outputs.zip && zip -r -1 $d/outputs.zip $d/shard_*_of_*/test.outputs/outputs.zip
+ if ls -l $d/shard_*_of_*/test.outputs/outputs.zip 2>/dev/null; then
+ zip -r -1 "$d/outputs.zip" $d/shard_*_of_*/test.outputs/outputs.zip
+ fi
done
find -L "bazel-testlogs" -name 'shard_*_of_*' | xargs rm -rf
# Move test logs to Kokoro directory. tar is used to conveniently perform
@@ -90,7 +92,13 @@ function collect_logs() {
echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
fi
- tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" .
+ time tar \
+ --verbose \
+ --create \
+ --gzip \
+ --file="${KOKORO_ARTIFACTS_DIR}/${archive}" \
+ --directory "${RUNSC_LOGS_DIR}" \
+ .
fi
fi
fi
diff --git a/test/perf/BUILD b/test/perf/BUILD
index 7a2bf10ed..346a28e16 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -29,7 +29,7 @@ syscall_test(
)
syscall_test(
- size = "large",
+ size = "enormous",
test = "//test/perf/linux:getdents_benchmark",
)
@@ -39,7 +39,7 @@ syscall_test(
)
syscall_test(
- size = "large",
+ size = "enormous",
test = "//test/perf/linux:gettid_benchmark",
)
@@ -87,7 +87,7 @@ syscall_test(
)
syscall_test(
- size = "large",
+ size = "enormous",
test = "//test/perf/linux:signal_benchmark",
)
@@ -102,7 +102,7 @@ syscall_test(
)
syscall_test(
- size = "large",
+ size = "enormous",
add_overlay = True,
test = "//test/perf/linux:unlink_benchmark",
)
diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc
index 0e03975b4..afc599ad2 100644
--- a/test/perf/linux/getdents_benchmark.cc
+++ b/test/perf/linux/getdents_benchmark.cc
@@ -141,7 +141,7 @@ void BM_GetdentsNewFD(benchmark::State& state) {
state.SetItemsProcessed(state.iterations());
}
-BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 16)->UseRealTime();
+BENCHMARK(BM_GetdentsNewFD)->Range(1, 1 << 12)->UseRealTime();
} // namespace
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 5e97c1867..56743a526 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -1,6 +1,6 @@
"""Defines a rule for syscall test targets."""
-load("//tools:defs.bzl", "loopback")
+load("//tools:defs.bzl", "default_platform", "loopback", "platforms")
def _runner_test_impl(ctx):
# Generate a runner binary.
@@ -94,19 +94,6 @@ def _syscall_test(
# Disable off-host networking.
tags.append("requires-net:loopback")
- # Add tag to prevent the tests from running in a Bazel sandbox.
- # TODO(b/120560048): Make the tests run without this tag.
- tags.append("no-sandbox")
-
- # TODO(b/112165693): KVM tests are tagged "manual" to until the platform is
- # more stable.
- if platform == "kvm":
- tags.append("manual")
- tags.append("requires-kvm")
-
- # TODO(b/112165693): Remove when tests pass reliably.
- tags.append("notap")
-
runner_args = [
# Arguments are passed directly to runner binary.
"--platform=" + platform,
@@ -149,6 +136,8 @@ def syscall_test(
add_hostinet: add a hostinet test.
tags: starting test tags.
"""
+ if not tags:
+ tags = []
_syscall_test(
test = test,
@@ -160,35 +149,26 @@ def syscall_test(
tags = tags,
)
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "kvm",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
-
- _syscall_test(
- test = test,
- shard_count = shard_count,
- size = size,
- platform = "ptrace",
- use_tmpfs = use_tmpfs,
- add_uds_tree = add_uds_tree,
- tags = tags,
- )
+ for (platform, platform_tags) in platforms.items():
+ _syscall_test(
+ test = test,
+ shard_count = shard_count,
+ size = size,
+ platform = platform,
+ use_tmpfs = use_tmpfs,
+ add_uds_tree = add_uds_tree,
+ tags = platform_tags + tags,
+ )
if add_overlay:
_syscall_test(
test = test,
shard_count = shard_count,
size = size,
- platform = "ptrace",
+ platform = default_platform,
use_tmpfs = False, # overlay is adding a writable tmpfs on top of root.
add_uds_tree = add_uds_tree,
- tags = tags,
+ tags = platforms[default_platform] + tags,
overlay = True,
)
@@ -198,10 +178,10 @@ def syscall_test(
test = test,
shard_count = shard_count,
size = size,
- platform = "ptrace",
+ platform = default_platform,
use_tmpfs = use_tmpfs,
add_uds_tree = add_uds_tree,
- tags = tags,
+ tags = platforms[default_platform] + tags,
file_access = "shared",
)
@@ -210,9 +190,9 @@ def syscall_test(
test = test,
shard_count = shard_count,
size = size,
- platform = "ptrace",
+ platform = default_platform,
use_tmpfs = use_tmpfs,
network = "host",
add_uds_tree = add_uds_tree,
- tags = tags,
+ tags = platforms[default_platform] + tags,
)
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
index 23bf7b5f6..f96e2415e 100644
--- a/test/runner/gtest/gtest.go
+++ b/test/runner/gtest/gtest.go
@@ -43,6 +43,10 @@ type TestCase struct {
// Name is the name of this individual test.
Name string
+ // all indicates that this will run without flags. This takes
+ // precendence over benchmark below.
+ all bool
+
// benchmark indicates that this is a benchmark. In this case, the
// suite will be empty, and we will use the appropriate test and
// benchmark flags.
@@ -57,6 +61,9 @@ func (tc TestCase) FullName() string {
// Args returns arguments to be passed when invoking the test.
func (tc TestCase) Args() []string {
+ if tc.all {
+ return []string{} // No arguments.
+ }
if tc.benchmark {
return []string{
fmt.Sprintf("%s=^$", filterTestFlag),
@@ -81,11 +88,16 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes
cmd := exec.Command(testBin, args...)
out, err := cmd.Output()
if err != nil {
- exitErr, ok := err.(*exec.ExitError)
- if !ok {
- return nil, fmt.Errorf("could not enumerate gtest tests: %v", err)
- }
- return nil, fmt.Errorf("could not enumerate gtest tests: %v\nstderr:\n%s", err, exitErr.Stderr)
+ // We failed to list tests with the given flags. Just
+ // return something that will run the binary with no
+ // flags, which should execute all tests.
+ return []TestCase{
+ TestCase{
+ Suite: "Default",
+ Name: "All",
+ all: true,
+ },
+ }, nil
}
// Parse test output.
@@ -114,7 +126,6 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes
Suite: suite,
Name: name,
})
-
}
// Finished?
@@ -127,6 +138,8 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes
cmd = exec.Command(testBin, args...)
out, err = cmd.Output()
if err != nil {
+ // We were able to enumerate tests above, but not benchmarks?
+ // We requested them, so we return an error in this case.
exitErr, ok := err.(*exec.ExitError)
if !ok {
return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v", err)
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index d69ac8356..3518e862d 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -258,6 +258,8 @@ syscall_test(
syscall_test(test = "//test/syscalls/linux:munmap_test")
+syscall_test(test = "//test/syscalls/linux:network_namespace_test")
+
syscall_test(
add_overlay = True,
test = "//test/syscalls/linux:open_create_test",
@@ -676,6 +678,8 @@ syscall_test(
test = "//test/syscalls/linux:truncate_test",
)
+syscall_test(test = "//test/syscalls/linux:tuntap_test")
+
syscall_test(test = "//test/syscalls/linux:udp_bind_test")
syscall_test(
diff --git a/test/syscalls/linux/32bit.cc b/test/syscalls/linux/32bit.cc
index c47a05181..3c825477c 100644
--- a/test/syscalls/linux/32bit.cc
+++ b/test/syscalls/linux/32bit.cc
@@ -74,7 +74,7 @@ void ExitGroup32(const char instruction[2], int code) {
"int $3\n"
:
: [ code ] "m"(code), [ ip ] "d"(m.ptr())
- : "rax", "rbx", "rsp");
+ : "rax", "rbx");
}
constexpr int kExitCode = 42;
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 05a818795..704bae17b 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -132,6 +132,17 @@ cc_library(
)
cc_library(
+ name = "socket_netlink_route_util",
+ testonly = 1,
+ srcs = ["socket_netlink_route_util.cc"],
+ hdrs = ["socket_netlink_route_util.h"],
+ deps = [
+ ":socket_netlink_util",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
name = "socket_test_util",
testonly = 1,
srcs = [
@@ -3430,6 +3441,25 @@ cc_binary(
],
)
+cc_binary(
+ name = "tuntap_test",
+ testonly = 1,
+ srcs = ["tuntap.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ "//test/syscalls/linux:socket_netlink_route_util",
+ "//test/util:capability_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "//test/util:posix_error",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
cc_library(
name = "udp_socket_test_cases",
testonly = 1,
@@ -3640,6 +3670,23 @@ cc_binary(
)
cc_binary(
+ name = "network_namespace_test",
+ testonly = 1,
+ srcs = ["network_namespace.cc"],
+ linkstatic = 1,
+ deps = [
+ ":socket_test_util",
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:memory_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+cc_binary(
name = "semaphore_test",
testonly = 1,
srcs = ["semaphore.cc"],
diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc
index 4dd302eed..4e473268c 100644
--- a/test/syscalls/linux/dev.cc
+++ b/test/syscalls/linux/dev.cc
@@ -153,6 +153,13 @@ TEST(DevTest, TTYExists) {
EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
}
+TEST(DevTest, NetTunExists) {
+ struct stat statbuf = {};
+ ASSERT_THAT(stat("/dev/net/tun", &statbuf), SyscallSucceeds());
+ // Check that it's a character device with rw-rw-rw- permissions.
+ EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666);
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/network_namespace.cc b/test/syscalls/linux/network_namespace.cc
new file mode 100644
index 000000000..6ea48c263
--- /dev/null
+++ b/test/syscalls/linux/network_namespace.cc
@@ -0,0 +1,121 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <net/if.h>
+#include <sched.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/synchronization/notification.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/memory_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+using TestFunc = std::function<PosixError()>;
+using RunFunc = std::function<PosixError(TestFunc)>;
+
+struct NamespaceStrategy {
+ RunFunc run;
+
+ static NamespaceStrategy Of(RunFunc run) {
+ NamespaceStrategy s;
+ s.run = run;
+ return s;
+ }
+};
+
+PosixError RunWithUnshare(TestFunc fn) {
+ PosixError err = PosixError(-1, "function did not return a value");
+ ScopedThread t([&] {
+ if (unshare(CLONE_NEWNET) != 0) {
+ err = PosixError(errno);
+ return;
+ }
+ err = fn();
+ });
+ t.Join();
+ return err;
+}
+
+PosixError RunWithClone(TestFunc fn) {
+ struct Args {
+ absl::Notification n;
+ TestFunc fn;
+ PosixError err;
+ };
+ Args args;
+ args.fn = fn;
+ args.err = PosixError(-1, "function did not return a value");
+
+ ASSIGN_OR_RETURN_ERRNO(
+ Mapping child_stack,
+ MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE));
+ pid_t child = clone(
+ +[](void *arg) {
+ Args *args = reinterpret_cast<Args *>(arg);
+ args->err = args->fn();
+ args->n.Notify();
+ syscall(SYS_exit, 0); // Exit manually. No return address on stack.
+ return 0;
+ },
+ reinterpret_cast<void *>(child_stack.addr() + kPageSize),
+ CLONE_NEWNET | CLONE_THREAD | CLONE_SIGHAND | CLONE_VM, &args);
+ if (child < 0) {
+ return PosixError(errno, "clone() failed");
+ }
+ args.n.WaitForNotification();
+ return args.err;
+}
+
+class NetworkNamespaceTest
+ : public ::testing::TestWithParam<NamespaceStrategy> {};
+
+TEST_P(NetworkNamespaceTest, LoopbackExists) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ EXPECT_NO_ERRNO(GetParam().run([]() {
+ // TODO(gvisor.dev/issue/1833): Update this to test that only "lo" exists.
+ // Check loopback device exists.
+ int sock = socket(AF_INET, SOCK_DGRAM, 0);
+ if (sock < 0) {
+ return PosixError(errno, "socket() failed");
+ }
+ struct ifreq ifr;
+ snprintf(ifr.ifr_name, IFNAMSIZ, "lo");
+ if (ioctl(sock, SIOCGIFINDEX, &ifr) < 0) {
+ return PosixError(errno, "ioctl() failed, lo cannot be found");
+ }
+ return NoError();
+ }));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ AllNetworkNamespaceTest, NetworkNamespaceTest,
+ ::testing::Values(NamespaceStrategy::Of(RunWithUnshare),
+ NamespaceStrategy::Of(RunWithClone)));
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/rseq/uapi.h b/test/syscalls/linux/rseq/uapi.h
index e3ff0579a..ca1d67691 100644
--- a/test/syscalls/linux/rseq/uapi.h
+++ b/test/syscalls/linux/rseq/uapi.h
@@ -15,14 +15,9 @@
#ifndef GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
#define GVISOR_TEST_SYSCALLS_LINUX_RSEQ_UAPI_H_
-// User-kernel ABI for restartable sequences.
+#include <stdint.h>
-// Standard types.
-//
-// N.B. This header will be included in targets that do have the standard
-// library, so we can't shadow the standard type names.
-using __u32 = __UINT32_TYPE__;
-using __u64 = __UINT64_TYPE__;
+// User-kernel ABI for restartable sequences.
#ifdef __x86_64__
// Syscall numbers.
@@ -32,20 +27,20 @@ constexpr int kRseqSyscall = 334;
#endif // __x86_64__
struct rseq_cs {
- __u32 version;
- __u32 flags;
- __u64 start_ip;
- __u64 post_commit_offset;
- __u64 abort_ip;
-} __attribute__((aligned(4 * sizeof(__u64))));
+ uint32_t version;
+ uint32_t flags;
+ uint64_t start_ip;
+ uint64_t post_commit_offset;
+ uint64_t abort_ip;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
// N.B. alignment is enforced by the kernel.
struct rseq {
- __u32 cpu_id_start;
- __u32 cpu_id;
+ uint32_t cpu_id_start;
+ uint32_t cpu_id;
struct rseq_cs* rseq_cs;
- __u32 flags;
-} __attribute__((aligned(4 * sizeof(__u64))));
+ uint32_t flags;
+} __attribute__((aligned(4 * sizeof(uint64_t))));
constexpr int kRseqFlagUnregister = 1 << 0;
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index e8f24a59e..f7d6139f1 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -447,6 +447,60 @@ TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgSucceeds) {
SyscallFailsWithErrno(EAGAIN));
}
+TEST_P(AllSocketPairTest, SendTimeoutDefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 89, .tv_usec = 42000};
+ EXPECT_THAT(
+ setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)),
+ SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 89);
+ EXPECT_EQ(actual_tv.tv_usec, 42000);
+}
+
+TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ struct timeval_with_extra {
+ struct timeval tv;
+ int64_t extra_data;
+ } ABSL_ATTRIBUTE_PACKED;
+
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 123000},
+ };
+
+ EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &tv_extra, sizeof(tv_extra)),
+ SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 123000);
+}
+
TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
@@ -491,18 +545,36 @@ TEST_P(AllSocketPairTest, SendTimeoutAllowsSendmsg) {
ASSERT_NO_FATAL_FAILURE(SendNullCmsg(sockets->first_fd(), buf, sizeof(buf)));
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSet) {
+TEST_P(AllSocketPairTest, RecvTimeoutDefault) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
- struct timeval tv {
- .tv_sec = 0, .tv_usec = 35
- };
+ timeval actual_tv = {.tv_sec = -1, .tv_usec = -1};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv_usec, 0);
+}
+
+TEST_P(AllSocketPairTest, SetGetRecvTimeout) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ timeval tv = {.tv_sec = 123, .tv_usec = 456000};
EXPECT_THAT(
setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)),
SyscallSucceeds());
+
+ timeval actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv_sec, 123);
+ EXPECT_EQ(actual_tv.tv_usec, 456000);
}
-TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
+TEST_P(AllSocketPairTest, SetGetRecvTimeoutLargerArg) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct timeval_with_extra {
@@ -510,13 +582,21 @@ TEST_P(AllSocketPairTest, SoRcvTimeoIsSetLargerArg) {
int64_t extra_data;
} ABSL_ATTRIBUTE_PACKED;
- timeval_with_extra tv_extra;
- tv_extra.tv.tv_sec = 0;
- tv_extra.tv.tv_usec = 25;
+ timeval_with_extra tv_extra = {
+ .tv = {.tv_sec = 0, .tv_usec = 432000},
+ };
EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
&tv_extra, sizeof(tv_extra)),
SyscallSucceeds());
+
+ timeval_with_extra actual_tv = {};
+ socklen_t len = sizeof(actual_tv);
+ EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_RCVTIMEO,
+ &actual_tv, &len),
+ SyscallSucceeds());
+ EXPECT_EQ(actual_tv.tv.tv_sec, 0);
+ EXPECT_EQ(actual_tv.tv.tv_usec, 432000);
}
TEST_P(AllSocketPairTest, RecvTimeoutRecvmsgOneSecondSucceeds) {
diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc
new file mode 100644
index 000000000..53eb3b6b2
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.cc
@@ -0,0 +1,163 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+
+#include <linux/if.h>
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr uint32_t kSeq = 12345;
+
+} // namespace
+
+PosixError DumpLinks(
+ const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn) {
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = seq;
+ req.ifm.ifi_family = AF_UNSPEC;
+
+ return NetlinkRequestResponse(fd, &req, sizeof(req), fn, false);
+}
+
+PosixErrorOr<std::vector<Link>> DumpLinks() {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ std::vector<Link> links;
+ RETURN_IF_ERRNO(DumpLinks(fd, kSeq, [&](const struct nlmsghdr* hdr) {
+ if (hdr->nlmsg_type != RTM_NEWLINK ||
+ hdr->nlmsg_len < NLMSG_SPACE(sizeof(struct ifinfomsg))) {
+ return;
+ }
+ const struct ifinfomsg* msg =
+ reinterpret_cast<const struct ifinfomsg*>(NLMSG_DATA(hdr));
+ const auto* rta = FindRtAttr(hdr, msg, IFLA_IFNAME);
+ if (rta == nullptr) {
+ // Ignore links that do not have a name.
+ return;
+ }
+
+ links.emplace_back();
+ links.back().index = msg->ifi_index;
+ links.back().type = msg->ifi_type;
+ links.back().name =
+ std::string(reinterpret_cast<const char*>(RTA_DATA(rta)));
+ }));
+ return links;
+}
+
+PosixErrorOr<absl::optional<Link>> FindLoopbackLink() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.type == ARPHRD_LOOPBACK) {
+ return absl::optional<Link>(link);
+ }
+ }
+ return absl::optional<Link>();
+}
+
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifaddrmsg ifaddr;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr));
+ req.hdr.nlmsg_type = RTM_NEWADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifaddr.ifa_index = index;
+ req.ifaddr.ifa_family = family;
+ req.ifaddr.ifa_prefixlen = prefixlen;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFA_LOCAL;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char pad[NLMSG_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+ req.ifinfo.ifi_flags = flags;
+ req.ifinfo.ifi_change = change;
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifinfo;
+ char attrbuf[512];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifinfo));
+ req.hdr.nlmsg_type = RTM_NEWLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifinfo.ifi_index = index;
+
+ struct rtattr* rta = reinterpret_cast<struct rtattr*>(
+ reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len));
+ rta->rta_type = IFLA_ADDRESS;
+ rta->rta_len = RTA_LENGTH(addrlen);
+ req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen);
+ memcpy(RTA_DATA(rta), addr, addrlen);
+
+ return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len);
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h
new file mode 100644
index 000000000..2c018e487
--- /dev/null
+++ b/test/syscalls/linux/socket_netlink_route_util.h
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
+
+#include <linux/netlink.h>
+#include <linux/rtnetlink.h>
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "test/syscalls/linux/socket_netlink_util.h"
+
+namespace gvisor {
+namespace testing {
+
+struct Link {
+ int index;
+ int16_t type;
+ std::string name;
+};
+
+PosixError DumpLinks(const FileDescriptor& fd, uint32_t seq,
+ const std::function<void(const struct nlmsghdr* hdr)>& fn);
+
+PosixErrorOr<std::vector<Link>> DumpLinks();
+
+PosixErrorOr<absl::optional<Link>> FindLoopbackLink();
+
+// LinkAddLocalAddr sets IFA_LOCAL attribute on the interface.
+PosixError LinkAddLocalAddr(int index, int family, int prefixlen,
+ const void* addr, int addrlen);
+
+// LinkChangeFlags changes interface flags. E.g. IFF_UP.
+PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change);
+
+// LinkSetMacAddr sets IFLA_ADDRESS attribute of the interface.
+PosixError LinkSetMacAddr(int index, const void* addr, int addrlen);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_ROUTE_UTIL_H_
diff --git a/test/syscalls/linux/tuntap.cc b/test/syscalls/linux/tuntap.cc
new file mode 100644
index 000000000..f6ac9d7b8
--- /dev/null
+++ b/test/syscalls/linux/tuntap.cc
@@ -0,0 +1,346 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_arp.h>
+#include <linux/if_ether.h>
+#include <linux/if_tun.h>
+#include <netinet/ip.h>
+#include <netinet/ip_icmp.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_split.h"
+#include "test/syscalls/linux/socket_netlink_route_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+constexpr int kIPLen = 4;
+
+constexpr const char kDevNetTun[] = "/dev/net/tun";
+constexpr const char kTapName[] = "tap0";
+
+constexpr const uint8_t kMacA[ETH_ALEN] = {0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA};
+constexpr const uint8_t kMacB[ETH_ALEN] = {0xBB, 0xBB, 0xBB, 0xBB, 0xBB, 0xBB};
+
+PosixErrorOr<std::set<std::string>> DumpLinkNames() {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ std::set<std::string> names;
+ for (const auto& link : links) {
+ names.emplace(link.name);
+ }
+ return names;
+}
+
+PosixErrorOr<absl::optional<Link>> GetLinkByName(const std::string& name) {
+ ASSIGN_OR_RETURN_ERRNO(auto links, DumpLinks());
+ for (const auto& link : links) {
+ if (link.name == name) {
+ return absl::optional<Link>(link);
+ }
+ }
+ return absl::optional<Link>();
+}
+
+struct pihdr {
+ uint16_t pi_flags;
+ uint16_t pi_protocol;
+} __attribute__((packed));
+
+struct ping_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct iphdr ip;
+ struct icmphdr icmp;
+ char payload[64];
+} __attribute__((packed));
+
+ping_pkt CreatePingPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ ping_pkt pkt = {};
+
+ pkt.pi.pi_protocol = htons(ETH_P_IP);
+
+ memcpy(pkt.eth.h_dest, dstmac, sizeof(pkt.eth.h_dest));
+ memcpy(pkt.eth.h_source, srcmac, sizeof(pkt.eth.h_source));
+ pkt.eth.h_proto = htons(ETH_P_IP);
+
+ pkt.ip.ihl = 5;
+ pkt.ip.version = 4;
+ pkt.ip.tos = 0;
+ pkt.ip.tot_len = htons(sizeof(struct iphdr) + sizeof(struct icmphdr) +
+ sizeof(pkt.payload));
+ pkt.ip.id = 1;
+ pkt.ip.frag_off = 1 << 6; // Do not fragment
+ pkt.ip.ttl = 64;
+ pkt.ip.protocol = IPPROTO_ICMP;
+ inet_pton(AF_INET, dstip, &pkt.ip.daddr);
+ inet_pton(AF_INET, srcip, &pkt.ip.saddr);
+ pkt.ip.check = IPChecksum(pkt.ip);
+
+ pkt.icmp.type = ICMP_ECHO;
+ pkt.icmp.code = 0;
+ pkt.icmp.checksum = 0;
+ pkt.icmp.un.echo.sequence = 1;
+ pkt.icmp.un.echo.id = 1;
+
+ strncpy(pkt.payload, "abcd", sizeof(pkt.payload));
+ pkt.icmp.checksum = ICMPChecksum(pkt.icmp, pkt.payload, sizeof(pkt.payload));
+
+ return pkt;
+}
+
+struct arp_pkt {
+ pihdr pi;
+ struct ethhdr eth;
+ struct arphdr arp;
+ uint8_t arp_sha[ETH_ALEN];
+ uint8_t arp_spa[kIPLen];
+ uint8_t arp_tha[ETH_ALEN];
+ uint8_t arp_tpa[kIPLen];
+} __attribute__((packed));
+
+std::string CreateArpPacket(const uint8_t srcmac[ETH_ALEN], const char* srcip,
+ const uint8_t dstmac[ETH_ALEN], const char* dstip) {
+ std::string buffer;
+ buffer.resize(sizeof(arp_pkt));
+
+ arp_pkt* pkt = reinterpret_cast<arp_pkt*>(&buffer[0]);
+ {
+ pkt->pi.pi_protocol = htons(ETH_P_ARP);
+
+ memcpy(pkt->eth.h_dest, kMacA, sizeof(pkt->eth.h_dest));
+ memcpy(pkt->eth.h_source, kMacB, sizeof(pkt->eth.h_source));
+ pkt->eth.h_proto = htons(ETH_P_ARP);
+
+ pkt->arp.ar_hrd = htons(ARPHRD_ETHER);
+ pkt->arp.ar_pro = htons(ETH_P_IP);
+ pkt->arp.ar_hln = ETH_ALEN;
+ pkt->arp.ar_pln = kIPLen;
+ pkt->arp.ar_op = htons(ARPOP_REPLY);
+
+ memcpy(pkt->arp_sha, srcmac, sizeof(pkt->arp_sha));
+ inet_pton(AF_INET, srcip, pkt->arp_spa);
+ memcpy(pkt->arp_tha, dstmac, sizeof(pkt->arp_tha));
+ inet_pton(AF_INET, dstip, pkt->arp_tpa);
+ }
+ return buffer;
+}
+
+} // namespace
+
+class TuntapTest : public ::testing::Test {
+ protected:
+ void TearDown() override {
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))) {
+ // Bring back capability if we had dropped it in test case.
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, true));
+ }
+ }
+};
+
+TEST_F(TuntapTest, CreateInterfaceNoCap) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ ASSERT_NO_ERRNO(SetCapability(CAP_NET_ADMIN, false));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ strncpy(ifr.ifr_name, kTapName, IFNAMSIZ);
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallFailsWithErrno(EPERM));
+}
+
+TEST_F(TuntapTest, CreateFixedNameInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_expect = ifr_set;
+ // See __tun_chr_ioctl() in net/drivers/tun.c.
+ ifr_expect.ifr_flags |= IFF_NOFILTER;
+
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(kTapName)));
+ EXPECT_THAT(memcmp(&ifr_expect, &ifr_get, sizeof(ifr_get)), ::testing::Eq(0));
+}
+
+TEST_F(TuntapTest, CreateInterface) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ // Empty ifr.ifr_name. Let kernel assign.
+
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ struct ifreq ifr_get = {};
+ EXPECT_THAT(ioctl(fd.get(), TUNGETIFF, &ifr_get),
+ SyscallSucceedsWithValue(0));
+
+ std::string ifname = ifr_get.ifr_name;
+ EXPECT_THAT(ifname, ::testing::StartsWith("tap"));
+ EXPECT_THAT(DumpLinkNames(),
+ IsPosixErrorOkAndHolds(::testing::Contains(ifname)));
+}
+
+TEST_F(TuntapTest, InvalidReadWrite) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ char buf[128] = {};
+ EXPECT_THAT(read(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EBADFD));
+}
+
+TEST_F(TuntapTest, WriteToDownDevice) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // FIXME: gVisor always creates enabled/up'd interfaces.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ // Device created should be down by default.
+ struct ifreq ifr = {};
+ ifr.ifr_flags = IFF_TAP;
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr), SyscallSucceedsWithValue(0));
+
+ char buf[128] = {};
+ EXPECT_THAT(write(fd.get(), buf, sizeof(buf)), SyscallFailsWithErrno(EIO));
+}
+
+// This test sets up a TAP device and pings kernel by sending ICMP echo request.
+//
+// It works as the following:
+// * Open /dev/net/tun, and create kTapName interface.
+// * Use rtnetlink to do initial setup of the interface:
+// * Assign IP address 10.0.0.1/24 to kernel.
+// * MAC address: kMacA
+// * Bring up the interface.
+// * Send an ICMP echo reqest (ping) packet from 10.0.0.2 (kMacB) to kernel.
+// * Loop to receive packets from TAP device/fd:
+// * If packet is an ICMP echo reply, it stops and passes the test.
+// * If packet is an ARP request, it responds with canned reply and resends
+// the
+// ICMP request packet.
+TEST_F(TuntapTest, PingKernel) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+
+ // Interface creation.
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kDevNetTun, O_RDWR));
+
+ struct ifreq ifr_set = {};
+ ifr_set.ifr_flags = IFF_TAP;
+ strncpy(ifr_set.ifr_name, kTapName, IFNAMSIZ);
+ EXPECT_THAT(ioctl(fd.get(), TUNSETIFF, &ifr_set),
+ SyscallSucceedsWithValue(0));
+
+ absl::optional<Link> link =
+ ASSERT_NO_ERRNO_AND_VALUE(GetLinkByName(kTapName));
+ ASSERT_TRUE(link.has_value());
+
+ // Interface setup.
+ struct in_addr addr;
+ inet_pton(AF_INET, "10.0.0.1", &addr);
+ EXPECT_NO_ERRNO(LinkAddLocalAddr(link->index, AF_INET, /*prefixlen=*/24,
+ &addr, sizeof(addr)));
+
+ if (!IsRunningOnGvisor()) {
+ // FIXME: gVisor doesn't support setting MAC address on interfaces yet.
+ EXPECT_NO_ERRNO(LinkSetMacAddr(link->index, kMacA, sizeof(kMacA)));
+
+ // FIXME: gVisor always creates enabled/up'd interfaces.
+ EXPECT_NO_ERRNO(LinkChangeFlags(link->index, IFF_UP, IFF_UP));
+ }
+
+ ping_pkt ping_req = CreatePingPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+ std::string arp_rep = CreateArpPacket(kMacB, "10.0.0.2", kMacA, "10.0.0.1");
+
+ // Send ping, this would trigger an ARP request on Linux.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+
+ // Receive loop to process inbound packets.
+ struct inpkt {
+ union {
+ pihdr pi;
+ ping_pkt ping;
+ arp_pkt arp;
+ };
+ };
+ while (1) {
+ inpkt r = {};
+ int n = read(fd.get(), &r, sizeof(r));
+ EXPECT_THAT(n, SyscallSucceeds());
+
+ if (n < sizeof(pihdr)) {
+ std::cerr << "Ignored packet, protocol: " << r.pi.pi_protocol
+ << " len: " << n << std::endl;
+ continue;
+ }
+
+ // Process ARP packet.
+ if (n >= sizeof(arp_pkt) && r.pi.pi_protocol == htons(ETH_P_ARP)) {
+ // Respond with canned ARP reply.
+ EXPECT_THAT(write(fd.get(), arp_rep.data(), arp_rep.size()),
+ SyscallSucceedsWithValue(arp_rep.size()));
+ // First ping request might have been dropped due to mac address not in
+ // ARP cache. Send it again.
+ EXPECT_THAT(write(fd.get(), &ping_req, sizeof(ping_req)),
+ SyscallSucceedsWithValue(sizeof(ping_req)));
+ }
+
+ // Process ping response packet.
+ if (n >= sizeof(ping_pkt) && r.pi.pi_protocol == ping_req.pi.pi_protocol &&
+ r.ping.ip.protocol == ping_req.ip.protocol &&
+ !memcmp(&r.ping.ip.saddr, &ping_req.ip.daddr, kIPLen) &&
+ !memcmp(&r.ping.ip.daddr, &ping_req.ip.saddr, kIPLen) &&
+ r.ping.icmp.type == 0 && r.ping.icmp.code == 0) {
+ // Ends and passes the test.
+ break;
+ }
+ }
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index 57b1a357c..740c7986d 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -21,6 +21,10 @@
#include <sys/socket.h>
#include <sys/types.h>
+#ifndef SIOCGSTAMP
+#include <linux/sockios.h>
+#endif
+
#include "gtest/gtest.h"
#include "absl/base/macros.h"
#include "absl/time/clock.h"
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index 6f091d759..905b16d41 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -8,7 +8,6 @@ load("@rules_pkg//:pkg.bzl", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar")
load("@io_bazel_rules_docker//go:image.bzl", _go_image = "go_image")
load("@io_bazel_rules_docker//container:container.bzl", _container_image = "container_image")
load("@pydeps//:requirements.bzl", _py_requirement = "requirement")
-load("//tools/bazeldefs:tags.bzl", _go_suffixes = "go_suffixes")
container_image = _container_image
cc_binary = _cc_binary
@@ -19,7 +18,6 @@ cc_test = _cc_test
cc_toolchain = "@bazel_tools//tools/cpp:current_cc_toolchain"
go_image = _go_image
go_embed_data = _go_embed_data
-go_suffixes = _go_suffixes
gtest = "@com_google_googletest//:gtest"
gbenchmark = "@com_google_benchmark//:benchmark"
loopback = "//tools/bazeldefs:loopback"
diff --git a/tools/bazeldefs/platforms.bzl b/tools/bazeldefs/platforms.bzl
new file mode 100644
index 000000000..92b0b5fc0
--- /dev/null
+++ b/tools/bazeldefs/platforms.bzl
@@ -0,0 +1,17 @@
+"""List of platforms."""
+
+# Platform to associated tags.
+platforms = {
+ "ptrace": [
+ # TODO(b/120560048): Make the tests run without this tag.
+ "no-sandbox",
+ ],
+ "kvm": [
+ "manual",
+ "local",
+ # TODO(b/120560048): Make the tests run without this tag.
+ "no-sandbox",
+ ],
+}
+
+default_platform = "ptrace"
diff --git a/tools/defs.bzl b/tools/defs.bzl
index 4eece2d83..15a310403 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -7,7 +7,9 @@ change for Google-internal and bazel-compatible rules.
load("//tools/go_stateify:defs.bzl", "go_stateify")
load("//tools/go_marshal:defs.bzl", "go_marshal", "marshal_deps", "marshal_test_deps")
-load("//tools/bazeldefs:defs.bzl", "go_suffixes", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:defs.bzl", _cc_binary = "cc_binary", _cc_flags_supplier = "cc_flags_supplier", _cc_library = "cc_library", _cc_proto_library = "cc_proto_library", _cc_test = "cc_test", _cc_toolchain = "cc_toolchain", _container_image = "container_image", _default_installer = "default_installer", _default_net_util = "default_net_util", _gbenchmark = "gbenchmark", _go_binary = "go_binary", _go_embed_data = "go_embed_data", _go_image = "go_image", _go_library = "go_library", _go_proto_library = "go_proto_library", _go_test = "go_test", _go_tool_library = "go_tool_library", _gtest = "gtest", _loopback = "loopback", _pkg_deb = "pkg_deb", _pkg_tar = "pkg_tar", _proto_library = "proto_library", _py_binary = "py_binary", _py_library = "py_library", _py_requirement = "py_requirement", _py_test = "py_test", _select_arch = "select_arch", _select_system = "select_system")
+load("//tools/bazeldefs:platforms.bzl", _default_platform = "default_platform", _platforms = "platforms")
+load("//tools/bazeldefs:tags.bzl", "go_suffixes")
# Delegate directly.
cc_binary = _cc_binary
@@ -33,6 +35,8 @@ select_system = _select_system
loopback = _loopback
default_installer = _default_installer
default_net_util = _default_net_util
+platforms = _platforms
+default_platform = _default_platform
def go_binary(name, **kwargs):
"""Wraps the standard go_binary.
@@ -84,7 +88,7 @@ def go_imports(name, src, out):
cmd = ("$(location @org_golang_x_tools//cmd/goimports:goimports) $(SRCS) > $@"),
)
-def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, **kwargs):
+def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = False, marshal_debug = False, **kwargs):
"""Wraps the standard go_library and does stateification and marshalling.
The recommended way is to use this rule with mostly identical configuration as the native
@@ -107,6 +111,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
imports: imports required for stateify.
stateify: whether statify is enabled (default: true).
marshal: whether marshal is enabled (default: false).
+ marshal_debug: whether the gomarshal tools emits debugging output (default: false).
**kwargs: standard go_library arguments.
"""
all_srcs = srcs
@@ -145,7 +150,10 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
go_marshal(
name = name + suffix + "_abi_autogen",
srcs = src_subset,
- debug = False,
+ debug = select({
+ "//tools/go_marshal:marshal_config_verbose": True,
+ "//conditions:default": marshal_debug,
+ }),
imports = imports,
package = name,
)
diff --git a/tools/go_marshal/BUILD b/tools/go_marshal/BUILD
index 80d9c0504..be49cf9c8 100644
--- a/tools/go_marshal/BUILD
+++ b/tools/go_marshal/BUILD
@@ -12,3 +12,8 @@ go_binary(
"//tools/go_marshal/gomarshal",
],
)
+
+config_setting(
+ name = "marshal_config_verbose",
+ values = {"define": "gomarshal=verbose"},
+)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 0fa868415..d365a1f3c 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -44,7 +44,8 @@ const (
// All recievers are single letters, so we don't allow import aliases to be a
// single letter.
var badIdents = []string{
- "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "len", "ptr", "src", "srcs", "task", "val",
+ "addr", "blk", "buf", "dst", "dsts", "err", "hdr", "idx", "inner", "len",
+ "ptr", "src", "srcs", "task", "val",
// All single-letter identifiers.
}
@@ -193,9 +194,9 @@ func (g *Generator) parse() ([]*ast.File, []*token.FileSet, error) {
return files, fsets, nil
}
-// collectMarshallabeTypes walks the parsed AST and collects a list of type
+// collectMarshallableTypes walks the parsed AST and collects a list of type
// declarations for which we need to generate the Marshallable interface.
-func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
+func (g *Generator) collectMarshallableTypes(a *ast.File, f *token.FileSet) []*ast.TypeSpec {
var types []*ast.TypeSpec
for _, decl := range a.Decls {
gdecl, ok := decl.(*ast.GenDecl)
@@ -222,14 +223,22 @@ func (g *Generator) collectMarshallabeTypes(a *ast.File, f *token.FileSet) []*as
continue
}
for _, spec := range gdecl.Specs {
- // We already confirmed we're in a type declaration earlier.
+ // We already confirmed we're in a type declaration earlier, so this
+ // cast will succeed.
t := spec.(*ast.TypeSpec)
- if _, ok := t.Type.(*ast.StructType); ok {
- debugfAt(f.Position(t.Pos()), "Collected marshallable type %s.\n", t.Name.Name)
+ switch t.Type.(type) {
+ case *ast.StructType:
+ debugfAt(f.Position(t.Pos()), "Collected marshallable struct %s.\n", t.Name.Name)
+ types = append(types, t)
+ continue
+ case *ast.Ident: // Newtype on primitive.
+ debugfAt(f.Position(t.Pos()), "Collected marshallable newtype on primitive %s.\n", t.Name.Name)
types = append(types, t)
continue
}
- debugf("Skipping declaration %v since it's not a struct declaration.\n", gdecl)
+ // A user specifically requested marshalling on this type, but we
+ // don't support it.
+ abortAt(f.Position(t.Pos()), fmt.Sprintf("Marshalling codegen was requested on type '%s', but go-marshal doesn't support this kind of declaration.\n", t.Name))
}
}
return types
@@ -269,12 +278,20 @@ func (g *Generator) collectImports(a *ast.File, f *token.FileSet) map[string]imp
}
func (g *Generator) generateOne(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- // We're guaranteed to have only struct type specs by now. See
- // Generator.collectMarshallabeTypes.
i := newInterfaceGenerator(t, fset)
- i.validate()
- i.emitMarshallable()
- return i
+ switch ty := t.Type.(type) {
+ case *ast.StructType:
+ i.validateStruct()
+ i.emitMarshallableForStruct()
+ return i
+ case *ast.Ident:
+ i.validatePrimitiveNewtype(ty)
+ i.emitMarshallableForPrimitiveNewtype()
+ return i
+ default:
+ // This should've been filtered out by collectMarshallabeTypes.
+ panic(fmt.Sprintf("Unexpected type %+v", ty))
+ }
}
// generateOneTestSuite generates a test suite for the automatically generated
@@ -320,7 +337,7 @@ func (g *Generator) Run() error {
for i, a := range asts {
// Collect type declarations marked for code generation and generate
// Marshallable interfaces.
- for _, t := range g.collectMarshallabeTypes(a, fsets[i]) {
+ for _, t := range g.collectMarshallableTypes(a, fsets[i]) {
impl := g.generateOne(t, fsets[i])
// Collect Marshallable types referenced by the generated code.
for ref, _ := range impl.ms {
diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go
index 3aa299ccd..ea1af998e 100644
--- a/tools/go_marshal/gomarshal/generator_interfaces.go
+++ b/tools/go_marshal/gomarshal/generator_interfaces.go
@@ -55,9 +55,6 @@ func (g *interfaceGenerator) typeName() string {
// newinterfaceGenerator creates a new interface generator.
func newInterfaceGenerator(t *ast.TypeSpec, fset *token.FileSet) *interfaceGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &interfaceGenerator{
t: t,
r: receiverName(t),
@@ -103,9 +100,31 @@ func (g *interfaceGenerator) abortAt(p token.Pos, msg string) {
abortAt(g.f.Position(p), msg)
}
-// validate ensures the type we're working with can be marshalled. These checks
-// are done ahead of time and in one place so we can make assumptions later.
-func (g *interfaceGenerator) validate() {
+func (g *interfaceGenerator) validatePrimitiveNewtype(t *ast.Ident) {
+ switch t.Name {
+ case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
+ // These are the only primitive types we're allow. Below, we provide
+ // suggestions for some disallowed types and reject them, then attempt
+ // to marshal any remaining types by invoking the marshal.Marshallable
+ // interface on them. If these types don't actually implement
+ // marshal.Marshallable, compilation of the generated code will fail
+ // with an appropriate error message.
+ return
+ case "int":
+ g.abortAt(t.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
+ case "uint":
+ g.abortAt(t.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
+ case "string":
+ g.abortAt(t.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
+ default:
+ debugfAt(g.f.Position(t.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
+ }
+}
+
+// validateStruct ensures the type we're working with can be marshalled. These
+// checks are done ahead of time and in one place so we can make assumptions
+// later.
+func (g *interfaceGenerator) validateStruct() {
g.forEachField(func(f *ast.Field) {
if len(f.Names) == 0 {
g.abortAt(f.Pos(), "Cannot marshal structs with embedded fields, give the field a name; use '_' for anonymous fields such as padding fields")
@@ -115,25 +134,7 @@ func (g *interfaceGenerator) validate() {
g.forEachField(func(f *ast.Field) {
fieldDispatcher{
primitive: func(_, t *ast.Ident) {
- switch t.Name {
- case "int8", "uint8", "byte", "int16", "uint16", "int32", "uint32", "int64", "uint64":
- // These are the only primitive types we're allow. Below, we
- // provide suggestions for some disallowed types and reject
- // them, then attempt to marshal any remaining types by
- // invoking the marshal.Marshallable interface on them. If
- // these types don't actually implement
- // marshal.Marshallable, compilation of the generated code
- // will fail with an appropriate error message.
- return
- case "int":
- g.abortAt(f.Pos(), "Type 'int' has ambiguous width, use int32 or int64")
- case "uint":
- g.abortAt(f.Pos(), "Type 'uint' has ambiguous width, use uint32 or uint64")
- case "string":
- g.abortAt(f.Pos(), "Type 'string' is dynamically-sized and cannot be marshalled, use a fixed size byte array '[...]byte' instead")
- default:
- debugfAt(g.f.Position(f.Pos()), fmt.Sprintf("Found derived type '%s', will attempt dispatch via marshal.Marshallable.\n", t.Name))
- }
+ g.validatePrimitiveNewtype(t)
},
selector: func(_, _, _ *ast.Ident) {
// No validation to perform on selector fields. However this
@@ -190,7 +191,8 @@ func (g *interfaceGenerator) shiftDynamic(bufVar, name string) {
g.emit("%s = %s[%s.SizeBytes():]\n", bufVar, bufVar, name)
}
-func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string) {
+// marshalStructFieldScalar writes a single scalar field from a struct to a byte slice.
+func (g *interfaceGenerator) marshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
case "int8", "uint8", "byte":
g.emit("%s[0] = byte(%s)\n", bufVar, accessor)
@@ -213,43 +215,27 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ string, bufVar string)
}
}
-func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string) {
+// unmarshalStructFieldScalar reads a single scalar field from a struct, from a
+// byte slice.
+func (g *interfaceGenerator) unmarshalStructFieldScalar(accessor, typ, bufVar string) {
switch typ {
- case "int8":
- g.emit("%s = int8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
- case "uint8":
- g.emit("%s = uint8(%s[0])\n", accessor, bufVar)
- g.shift(bufVar, 1)
case "byte":
g.emit("%s = %s[0]\n", accessor, bufVar)
g.shift(bufVar, 1)
-
- case "int16":
- g.recordUsedImport("usermem")
- g.emit("%s = int16(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, bufVar)
- g.shift(bufVar, 2)
- case "uint16":
+ case "int8", "uint8":
+ g.emit("%s = %s(%s[0])\n", accessor, typ, bufVar)
+ g.shift(bufVar, 1)
+ case "int16", "uint16":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint16(%s[:2])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint16(%s[:2]))\n", accessor, typ, bufVar)
g.shift(bufVar, 2)
-
- case "int32":
- g.recordUsedImport("usermem")
- g.emit("%s = int32(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, bufVar)
- g.shift(bufVar, 4)
- case "uint32":
+ case "int32", "uint32":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint32(%s[:4])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint32(%s[:4]))\n", accessor, typ, bufVar)
g.shift(bufVar, 4)
-
- case "int64":
- g.recordUsedImport("usermem")
- g.emit("%s = int64(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, bufVar)
- g.shift(bufVar, 8)
- case "uint64":
+ case "int64", "uint64":
g.recordUsedImport("usermem")
- g.emit("%s = usermem.ByteOrder.Uint64(%s[:8])\n", accessor, bufVar)
+ g.emit("%s = %s(usermem.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar)
g.shift(bufVar, 8)
default:
g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor)
@@ -258,6 +244,49 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ string, bufVar string
}
}
+// marshalPrimitiveScalar writes a single primitive variable to a byte slice.
+func (g *interfaceGenerator) marshalPrimitiveScalar(accessor, typ, bufVar string) {
+ switch typ {
+ case "int8", "uint8", "byte":
+ g.emit("%s[0] = byte(*%s)\n", bufVar, accessor)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint16(%s[:2], uint16(*%s))\n", bufVar, accessor)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint32(%s[:4], uint32(*%s))\n", bufVar, accessor)
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("usermem.ByteOrder.PutUint64(%s[:8], uint64(*%s))\n", bufVar, accessor)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.MarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
+// unmarshalPrimitiveScalar read a single primitive variable from a byte slice.
+func (g *interfaceGenerator) unmarshalPrimitiveScalar(accessor, typ, bufVar, typeCast string) {
+ switch typ {
+ case "byte":
+ g.emit("*%s = %s(%s[0])\n", accessor, typeCast, bufVar)
+ case "int8", "uint8":
+ g.emit("*%s = %s(%s(%s[0]))\n", accessor, typeCast, typ, bufVar)
+ case "int16", "uint16":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint16(%s[:2])))\n", accessor, typeCast, typ, bufVar)
+ case "int32", "uint32":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint32(%s[:4])))\n", accessor, typeCast, typ, bufVar)
+
+ case "int64", "uint64":
+ g.recordUsedImport("usermem")
+ g.emit("*%s = %s(%s(usermem.ByteOrder.Uint64(%s[:8])))\n", accessor, typeCast, typ, bufVar)
+ default:
+ g.emit("inner := (*%s)(%s)\n", typ, accessor)
+ g.emit("inner.UnmarshalBytes(%s[:%s.SizeBytes()])\n", bufVar, accessor)
+ }
+}
+
// areFieldsPackedExpression returns a go expression checking whether g.t's fields are
// packed. Returns "", false if g.t has no fields that may be potentially
// packed, otherwise returns <clause>, true, where <clause> is an expression
@@ -274,7 +303,7 @@ func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) {
return strings.Join(cs, " && "), true
}
-func (g *interfaceGenerator) emitMarshallable() {
+func (g *interfaceGenerator) emitMarshallableForStruct() {
// Is g.t a packed struct without consideing field types?
thisPacked := true
g.forEachField(func(f *ast.Field) {
@@ -357,10 +386,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.marshalScalar(g.fieldAccessor(n), t.Name, "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), t.Name, "dst")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.marshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
+ g.marshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "dst")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -377,9 +406,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.marshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "dst")
+ g.marshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "dst")
})
g.emit("}\n")
},
@@ -406,10 +435,10 @@ func (g *interfaceGenerator) emitMarshallable() {
}
return
}
- g.unmarshalScalar(g.fieldAccessor(n), t.Name, "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), t.Name, "src")
},
selector: func(n, tX, tSel *ast.Ident) {
- g.unmarshalScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
+ g.unmarshalStructFieldScalar(g.fieldAccessor(n), fmt.Sprintf("%s.%s", tX.Name, tSel.Name), "src")
},
array: func(n, t *ast.Ident, size int) {
if n.Name == "_" {
@@ -426,9 +455,9 @@ func (g *interfaceGenerator) emitMarshallable() {
return
}
- g.emit("for i := 0; i < %d; i++ {\n", size)
+ g.emit("for idx := 0; idx < %d; idx++ {\n", size)
g.inIndent(func() {
- g.unmarshalScalar(fmt.Sprintf("%s[i]", g.fieldAccessor(n)), t.Name, "src")
+ g.unmarshalStructFieldScalar(fmt.Sprintf("%s[idx]", g.fieldAccessor(n)), t.Name, "src")
})
g.emit("}\n")
},
@@ -507,13 +536,14 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName())
g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
g.emit("%s.MarshalBytes(buf)\n", g.r)
- g.emit("return task.CopyOutBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("return err\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -539,11 +569,11 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
- g.emit("len, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
g.emit("// must live until after the CopyOutBytes.\n")
g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return len, err\n")
+ g.emit("return err\n")
} else {
fallback()
}
@@ -553,20 +583,20 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
g.recordUsedImport("marshal")
g.recordUsedImport("usermem")
- g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName())
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
g.inIndent(func() {
fallback := func() {
g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName())
g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes())\n", g.r)
- g.emit("n, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
g.emit("if err != nil {\n")
g.inIndent(func() {
- g.emit("return n, err\n")
+ g.emit("return err\n")
})
g.emit("}\n")
g.emit("%s.UnmarshalBytes(buf)\n", g.r)
- g.emit("return n, nil\n")
+ g.emit("return nil\n")
}
if thisPacked {
g.recordUsedImport("reflect")
@@ -592,11 +622,11 @@ func (g *interfaceGenerator) emitMarshallable() {
g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
- g.emit("len, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
g.emit("// must live until after the CopyInBytes.\n")
g.emit("runtime.KeepAlive(%s)\n", g.r)
- g.emit("return len, err\n")
+ g.emit("return err\n")
} else {
fallback()
}
@@ -649,3 +679,144 @@ func (g *interfaceGenerator) emitMarshallable() {
})
g.emit("}\n\n")
}
+
+// emitMarshallableForPrimitiveNewtype outputs code to implement the
+// marshal.Marshallable interface for a newtype on a primitive. Primitive
+// newtypes are always packed, so we can omit the various fallbacks required for
+// non-packed structs.
+func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype() {
+ g.recordUsedImport("io")
+ g.recordUsedImport("marshal")
+ g.recordUsedImport("reflect")
+ g.recordUsedImport("runtime")
+ g.recordUsedImport("safecopy")
+ g.recordUsedImport("unsafe")
+ g.recordUsedImport("usermem")
+
+ nt := g.t.Type.(*ast.Ident)
+
+ g.emit("// SizeBytes implements marshal.Marshallable.SizeBytes.\n")
+ g.emit("func (%s *%s) SizeBytes() int {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ if size, dynamic := g.scalarSize(nt); !dynamic {
+ g.emit("return %d\n", size)
+ } else {
+ g.emit("return (*%s)(nil).SizeBytes()\n", nt.Name)
+ }
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n")
+ g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.marshalPrimitiveScalar(g.r, nt.Name, "dst")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n")
+ g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName())
+ })
+ g.emit("}\n\n")
+
+ g.emit("// Packed implements marshal.Marshallable.Packed.\n")
+ g.emit("func (%s *%s) Packed() bool {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Scalar newtypes are always packed.\n")
+ g.emit("return true\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n")
+ g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyIn(dst, unsafe.Pointer(%s))\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n")
+ g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("safecopy.CopyOut(unsafe.Pointer(%s), src)\n", g.r)
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n")
+ g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ // Fast serialization.
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyOutBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyOutBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n")
+ g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) error {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("_, err := task.CopyInBytes(addr, buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the CopyInBytes.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return err\n")
+ })
+ g.emit("}\n\n")
+
+ g.emit("// WriteTo implements io.WriterTo.WriteTo.\n")
+ g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName())
+ g.inIndent(func() {
+ g.emit("// Bypass escape analysis on %s. The no-op arithmetic operation on the\n", g.r)
+ g.emit("// pointer makes the compiler think val doesn't depend on %s.\n", g.r)
+ g.emit("// See src/runtime/stubs.go:noescape() in the golang toolchain.\n")
+ g.emit("ptr := unsafe.Pointer(%s)\n", g.r)
+ g.emit("val := uintptr(ptr)\n")
+ g.emit("val = val^0\n\n")
+
+ g.emit("// Construct a slice backed by %s's underlying memory.\n", g.r)
+ g.emit("var buf []byte\n")
+ g.emit("hdr := (*reflect.SliceHeader)(unsafe.Pointer(&buf))\n")
+ g.emit("hdr.Data = val\n")
+ g.emit("hdr.Len = %s.SizeBytes()\n", g.r)
+ g.emit("hdr.Cap = %s.SizeBytes()\n\n", g.r)
+
+ g.emit("len, err := w.Write(buf)\n")
+ g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", g.r)
+ g.emit("// must live until after the Write.\n")
+ g.emit("runtime.KeepAlive(%s)\n", g.r)
+ g.emit("return int64(len), err\n")
+
+ })
+ g.emit("}\n\n")
+
+}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index 8c28b00d0..8ba47eb67 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -49,9 +49,6 @@ type testGenerator struct {
}
func newTestGenerator(t *ast.TypeSpec) *testGenerator {
- if _, ok := t.Type.(*ast.StructType); !ok {
- panic(fmt.Sprintf("Attempting to generate code for a not struct type %v", t))
- }
g := &testGenerator{
t: t,
r: receiverName(t),
@@ -69,14 +66,6 @@ func (g *testGenerator) typeName() string {
return g.t.Name.Name
}
-func (g *testGenerator) forEachField(fn func(f *ast.Field)) {
- // This is guaranteed to succeed because g.t is always a struct.
- st := g.t.Type.(*ast.StructType)
- for _, field := range st.Fields.List {
- fn(field)
- }
-}
-
func (g *testGenerator) testFuncName(base string) string {
return fmt.Sprintf("%s%s", base, strings.Title(g.t.Name.Name))
}
@@ -89,10 +78,10 @@ func (g *testGenerator) inTestFunction(name string, body func()) {
func (g *testGenerator) emitTestNonZeroSize() {
g.inTestFunction("TestSizeNonZero", func() {
- g.emit("x := &%s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("if x.SizeBytes() == 0 {\n")
g.inIndent(func() {
- g.emit("t.Fatal(\"Marshallable.Size() should not return zero\")\n")
+ g.emit("t.Fatal(\"Marshallable.SizeBytes() should not return zero\")\n")
})
g.emit("}\n")
})
@@ -100,7 +89,7 @@ func (g *testGenerator) emitTestNonZeroSize() {
func (g *testGenerator) emitTestSuspectAlignment() {
g.inTestFunction("TestSuspectAlignment", func() {
- g.emit("x := %s{}\n", g.typeName())
+ g.emit("var x %v\n", g.typeName())
g.emit("analysis.AlignmentCheck(t, reflect.TypeOf(x))\n")
})
}
diff --git a/tools/go_marshal/marshal/marshal.go b/tools/go_marshal/marshal/marshal.go
index 20353850d..f129788e0 100644
--- a/tools/go_marshal/marshal/marshal.go
+++ b/tools/go_marshal/marshal/marshal.go
@@ -91,12 +91,12 @@ type Marshallable interface {
// marshalled does not escape. The implementation should avoid creating
// extra copies in memory by directly deserializing to the object's
// underlying memory.
- CopyIn(task Task, addr usermem.Addr) (int, error)
+ CopyIn(task Task, addr usermem.Addr) error
// CopyOut serializes a Marshallable type to a task's memory. This may only
// be called from a task goroutine. This is more efficient than calling
// MarshalUnsafe on Marshallable.Packed types, as the type being serialized
// does not escape. The implementation should avoid creating extra copies in
// memory by directly serializing from the object's underlying memory.
- CopyOut(task Task, addr usermem.Addr) (int, error)
+ CopyOut(task Task, addr usermem.Addr) error
}
diff --git a/tools/go_marshal/test/test.go b/tools/go_marshal/test/test.go
index 8de02d707..93229dedb 100644
--- a/tools/go_marshal/test/test.go
+++ b/tools/go_marshal/test/test.go
@@ -103,3 +103,13 @@ type Stat struct {
CTime Timespec
_ [3]int64
}
+
+// SignalSet is an example marshallable newtype on a primitive.
+//
+// +marshal
+type SignalSet uint64
+
+// SignalSetAlias is an example newtype on another marshallable type.
+//
+// +marshal
+type SignalSetAlias SignalSet
diff --git a/tools/installers/master.sh b/tools/installers/master.sh
index 7b1956454..52f9734a6 100755
--- a/tools/installers/master.sh
+++ b/tools/installers/master.sh
@@ -15,6 +15,21 @@
# limitations under the License.
# Install runsc from the master branch.
+set -e
+
curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add -
add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main"
-apt-get update && apt-get install -y runsc
+while true; do
+ if apt-get update; then
+ apt-get install -y runsc
+ break
+ fi
+ result=$?
+ # Check if apt update failed to aquire the file lock.
+ if [[ $result -ne 100 ]]; then
+ exit $result
+ fi
+done
+runsc install
+service docker restart
+