diff options
756 files changed, 33932 insertions, 13328 deletions
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 49a1ba697..50d187633 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -23,7 +23,7 @@ reproduced with software that is publicly available. Please include the following details of your environment: -* `runsc -v` +* `runsc -version` * `docker version` or `docker info` (if available) * `kubectl version` and `kubectl get nodes` (if using Kubernetes) * `uname -a` diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index cf782a580..1e1677ab5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,9 +3,11 @@ on: push: branches: - master + - feature/** pull_request: branches: - master + - feature/** jobs: default: diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4da3853b2..3a6a592d1 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -6,6 +6,7 @@ on: pull_request: branches: - master + - feature/** jobs: generate: @@ -49,7 +50,12 @@ jobs: key: ${{ runner.os }}-bazel-${{ hashFiles('WORKSPACE') }} restore-keys: | ${{ runner.os }}-bazel- + # Create gopath to merge the changes. The first execution will create + # symlinks to the cache, e.g. bazel-bin. Once the cache is setup, delete + # old gopath files that may exist from previous runs (and could contain + # files that are now deleted). Then run gopath again for good. - run: | + make build TARGETS="//:gopath" rm -rf bazel-bin/gopath make build TARGETS="//:gopath" - run: tools/go_branch.sh @@ -114,7 +114,7 @@ runsc: ## Builds the runsc binary. .PHONY: runsc debian: ## Builds the debian packages. - @$(call submake,build OPTIONS="-c opt" TARGETS="//runsc:runsc-debian") + @$(call submake,build OPTIONS="-c opt" TARGETS="//debian:debian") .PHONY: debian smoke-tests: ## Runs a simple smoke test after build runsc. @@ -129,10 +129,9 @@ tests: unit-tests @$(call submake,test TARGETS="test/syscalls/...") .PHONY: tests - integration-tests: ## Run all standard integration tests. integration-tests: docker-tests overlay-tests hostnet-tests swgso-tests -integration-tests: do-tests kvm-tests root-tests containerd-tests +integration-tests: do-tests kvm-tests containerd-test-1.3.4 .PHONY: integration-tests network-tests: ## Run all networking integration tests. @@ -157,6 +156,10 @@ syscall-tests: syscall-ptrace-tests syscall-kvm-tests syscall-native-tests @$(call submake,install-test-runtime) @$(call submake,test-runtime OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") +%-runtime-tests_vfs2: load-runtimes_% + @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2") + @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_timeout=10800" TARGETS="//test/runtimes:$*") + do-tests: runsc @$(call submake,run TARGETS="//runsc" ARGS="--rootless do true") @$(call submake,run TARGETS="//runsc" ARGS="--rootless -network=none do true") @@ -182,6 +185,7 @@ swgso-tests: load-basic-images @$(call submake,install-test-runtime RUNTIME="swgso" ARGS="--software-gso=true --gso=false") @$(call submake,test-runtime RUNTIME="swgso" TARGETS="$(INTEGRATION_TARGETS)") .PHONY: swgso-tests + hostnet-tests: load-basic-images @$(call submake,install-test-runtime RUNTIME="hostnet" ARGS="--network=host") @$(call submake,test-runtime RUNTIME="hostnet" OPTIONS="--test_arg=-checkpoint=false" TARGETS="$(INTEGRATION_TARGETS)") @@ -196,6 +200,8 @@ kvm-tests: load-basic-images .PHONY: kvm-tests iptables-tests: load-iptables + @sudo modprobe iptable_filter + @sudo modprobe ip6table_filter @$(call submake,test-runtime RUNTIME="runc" TARGETS="//test/iptables:iptables_test") @$(call submake,install-test-runtime RUNTIME="iptables" ARGS="--net-raw") @$(call submake,test-runtime RUNTIME="iptables" TARGETS="//test/iptables:iptables_test") @@ -207,21 +213,18 @@ packetdrill-tests: load-packetdrill .PHONY: packetdrill-tests packetimpact-tests: load-packetimpact - @sudo modprobe iptable_filter ip6table_filter + @sudo modprobe iptable_filter + @sudo modprobe ip6table_filter @$(call submake,install-test-runtime RUNTIME="packetimpact") @$(call submake,test-runtime OPTIONS="--jobs=HOST_CPUS*3 --local_test_jobs=HOST_CPUS*3" RUNTIME="packetimpact" TARGETS="$(shell $(MAKE) query TARGETS='attr(tags, packetimpact, tests(//...))')") .PHONY: packetimpact-tests -root-tests: load-basic-images - @$(call submake,install-test-runtime) - @$(call submake,sudo TARGETS="//test/root:root_test" ARGS="-test.v") -.PHONY: root-tests - # Specific containerd version tests. -containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd install-test-runtime +containerd-test-%: load-basic_alpine load-basic_python load-basic_busybox load-basic_resolv load-basic_httpd load-basic_ubuntu + @$(call submake,install-test-runtime RUNTIME="root") @CONTAINERD_VERSION=$* $(MAKE) sudo TARGETS="tools/installers:containerd" @$(MAKE) sudo TARGETS="tools/installers:shim" - @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="-test.v" + @$(MAKE) sudo TARGETS="test/root:root_test" ARGS="--runtime=root -test.v" # Note that we can't run containerd-test-1.1.8 tests here. # @@ -294,15 +297,17 @@ $(RELEASE_KEY): echo Name-Email: test@example.com >> $$C && \ echo Expire-Date: 0 >> $$C && \ echo %commit >> $$C && \ - gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --keyring $$T --no-tty --gen-key $$C && \ - gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --keyring $$T --secret-keyring $$T > $@; \ + gpg --batch $(GPG_TEST_OPTIONS) --passphrase '' --no-default-keyring --secret-keyring $$T --no-tty --gen-key $$C && \ + gpg --batch $(GPG_TEST_OPTIONS) --export-secret-keys --no-default-keyring --secret-keyring $$T > $@; \ rc=$$?; rm -f $$T $$C; exit $$rc release: $(RELEASE_KEY) ## Builds a release. @mkdir -p $(RELEASE_ROOT) @T=$$(mktemp -d /tmp/release.XXXXXX); \ - $(call submake,copy TARGETS="runsc" DESTINATION=$$T) && \ - $(call submake,copy TARGETS="runsc:runsc-debian" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="//runsc:runsc" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="//shim/v1:gvisor-containerd-shim" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="//shim/v2:containerd-shim-runsc-v1" DESTINATION=$$T) && \ + $(call submake,copy TARGETS="//debian:debian" DESTINATION=$$T) && \ NIGHTLY=$(RELEASE_NIGHTLY) tools/make_release.sh $(RELEASE_KEY) $(RELEASE_ROOT) $$T/*; \ rc=$$?; rm -rf $$T; exit $$rc .PHONY: release @@ -2,6 +2,7 @@ ![](https://github.com/google/gvisor/workflows/Build/badge.svg) [![gVisor chat](https://badges.gitter.im/gvisor/community.png)](https://gitter.im/gvisor/community) +[![code search](https://img.shields.io/badge/code-search-blue)](https://cs.opensource.google/gvisor/gvisor) ## What is gVisor? diff --git a/debian/BUILD b/debian/BUILD new file mode 100644 index 000000000..331f44a5c --- /dev/null +++ b/debian/BUILD @@ -0,0 +1,59 @@ +load("//tools:defs.bzl", "pkg_deb", "pkg_tar") + +package(licenses = ["notice"]) + +pkg_tar( + name = "debian-bin", + srcs = [ + "//runsc", + "//shim/v1:gvisor-containerd-shim", + "//shim/v2:containerd-shim-runsc-v1", + ], + mode = "0755", + package_dir = "/usr/bin", +) + +pkg_tar( + name = "debian-data", + extension = "tar.gz", + deps = [ + ":debian-bin", + "//shim:config", + ], +) + +genrule( + name = "debian-version", + # Note that runsc must appear in the srcs parameter and not the tools + # parameter, otherwise it will not be stamped. This is reasonable, as tools + # may be encoded differently in the build graph (cached more aggressively + # because they are assumes to be hermetic). + srcs = ["//runsc"], + outs = ["version.txt"], + # Note that the little dance here is necessary because files in the $(SRCS) + # attribute are not executable by default, and we can't touch in place. + cmd = "cp $(location //runsc:runsc) $(@D)/runsc && \ + chmod a+x $(@D)/runsc && \ + $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ + rm -f $(@D)/runsc", + stamp = 1, +) + +pkg_deb( + name = "debian", + architecture = "amd64", + data = ":debian-data", + # Note that the description_file will be flatten (all newlines removed), + # and therefore it is kept to a simple one-line description. The expected + # format for debian packages is "short summary\nLonger explanation of + # tool." and this is impossible with the flattening. + description_file = "description", + homepage = "https://gvisor.dev/", + maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>", + package = "runsc", + postinst = "postinst.sh", + version_file = ":version.txt", + visibility = [ + "//visibility:public", + ], +) diff --git a/runsc/debian/description b/debian/description index 9e8e08805..9e8e08805 100644 --- a/runsc/debian/description +++ b/debian/description diff --git a/runsc/debian/postinst.sh b/debian/postinst.sh index d1e28e17b..6a326f823 100755 --- a/runsc/debian/postinst.sh +++ b/debian/postinst.sh @@ -21,7 +21,7 @@ fi # Update docker configuration. if [ -f /etc/docker/daemon.json ]; then runsc install - if systemctl status docker 2>/dev/null; then + if systemctl is-active -q docker; then systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 fi fi diff --git a/g3doc/architecture_guide/security.md b/g3doc/architecture_guide/security.md index b99b86332..9363d834c 100644 --- a/g3doc/architecture_guide/security.md +++ b/g3doc/architecture_guide/security.md @@ -104,7 +104,7 @@ interactions with a guest operating system and a set of virtualized hardware devices. These hardware devices are then implemented via the host System API by a Virtual Machine Monitor (VMM). The Sentry similarly prevents direct interactions by providing its own implementation of the System API that the -application must interact with. Applications are not able to to directly craft +application must interact with. Applications are not able to directly craft specific arguments or flags for the host System API, or interact directly with host primitives. diff --git a/g3doc/style.md b/g3doc/style.md index d10549fe9..8258b0233 100644 --- a/g3doc/style.md +++ b/g3doc/style.md @@ -46,6 +46,15 @@ protected. Each field or variable protected by a mutex should state as such in a comment on the field or variable declaration. +### Function comments + +Functions with special entry conditions (e.g., a lock must be held) should state +these conditions in a `Preconditions:` comment block. One condition per line; +multiple conditions are specified with a bullet (`*`). + +Functions with notable exit conditions (e.g., a `Done` function must eventually +be called by the caller) can similarly have a `Postconditions:` block. + ### Unused returns Unused returns should be explicitly ignored with underscores. If there is a diff --git a/g3doc/user_guide/FAQ.md b/g3doc/user_guide/FAQ.md index 89df65e99..514fe3918 100644 --- a/g3doc/user_guide/FAQ.md +++ b/g3doc/user_guide/FAQ.md @@ -74,11 +74,10 @@ directories. ### I'm getting an error like: `panic: unable to attach: operation not permitted` or `fork/exec /proc/self/exe: invalid argument: unknown` {#runsc-perms} -Make sure that permissions and the owner is correct on the `runsc` binary. +Make sure that permissions is correct on the `runsc` binary. ```bash -sudo chown root:root /usr/local/bin/runsc -sudo chmod 0755 /usr/local/bin/runsc +sudo chmod a+rx /usr/local/bin/runsc ``` ### I'm getting an error like `mount submount "/etc/hostname": creating mount with source ".../hostname": input/output error: unknown.` {#memlock} @@ -96,6 +95,30 @@ containerd. See [issue #1765](https://gvisor.dev/issue/1765) for more details. +### I'm getting an error like `RuntimeHandler "runsc" not supported` {#runtime-handler} + +This error indicates that the Kubernetes CRI runtime was not set up to handle +`runsc` as a runtime handler. Please ensure that containerd configuration has +been created properly and containerd has been restarted. See the +[containerd quick start](containerd/quick_start.md) for more details. + +If you have ensured that containerd has been set up properly and you used +kubeadm to create your cluster please check if Docker is also installed on that +system. Kubeadm prefers using Docker if both Docker and containerd are +installed. + +Please recreate your cluster and set the `--cni-socket` option on kubeadm +commands. For example: + +```bash +kubeadm init --cni-socket=/var/run/containerd/containerd.sock` ... +``` + +To fix an existing cluster edit the `/var/lib/kubelet/kubeadm-flags.env` file +and set the `--container-runtime` flag to `remote` and set the +`--container-runtime-endpoint` flag to point to the containerd socket. e.g. +`/var/run/containerd/containerd.sock`. + ### My container cannot resolve another container's name when using Docker user defined bridge {#docker-bridge} This is normally indicated by errors like `bad address 'container-name'` when diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md index 9afdd264d..abb9e8582 100644 --- a/g3doc/user_guide/install.md +++ b/g3doc/user_guide/install.md @@ -5,6 +5,68 @@ > Note: gVisor supports only x86\_64 and requires Linux 4.14.77+ > ([older Linux](./networking.md#gso)). +## Install latest release {#install-latest} + +To download and install the latest release manually follow these steps: + +```bash +( + set -e + URL=https://storage.googleapis.com/gvisor/releases/release/latest + wget ${URL}/runsc ${URL}/runsc.sha512 + sha512sum -c runsc.sha512 + rm -f runsc.sha512 + sudo mv runsc /usr/local/bin + sudo chmod a+rx /usr/local/bin/runsc +) +``` + +To install gVisor with Docker, run the following commands: + +```bash +/usr/local/bin/runsc install +sudo systemctl restart docker +docker run --rm --runtime=runsc hello-world +``` + +For more details about using gVisor with Docker, see +[Docker Quick Start](./quick_start/docker.md) + +Note: It is important to copy `runsc` to a location that is readable and +executable to all users, since `runsc` executes itself as user `nobody` to avoid +unnecessary privileges. The `/usr/local/bin` directory is a good place to put +the `runsc` binary. + +## Install from an `apt` repository + +First, appropriate dependencies must be installed to allow `apt` to install +packages via https: + +```bash +sudo apt-get update && \ +sudo apt-get install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg-agent \ + software-properties-common +``` + +Next, the configure the key used to sign archives and the repository: + +```bash +curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" +``` + +Now the runsc package can be installed: + +```bash +sudo apt-get update && sudo apt-get install -y runsc +``` + +If you have Docker installed, it will be automatically configured. + ## Versions The `runsc` binaries and repositories are available in multiple versions and @@ -21,12 +83,16 @@ Binaries are available for every commit on the `master` branch, and are available at the following URL: `https://storage.googleapis.com/gvisor/releases/master/latest/runsc` +`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512` -Checksums for the release binary are at: +You can use this link with the steps described in +[Install latest release](#install-latest). -`https://storage.googleapis.com/gvisor/releases/master/latest/runsc.sha512` +For `apt` installation, use the `master` to configure the repository: -For `apt` installation, use the `master` as the `${DIST}` below. +```bash +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases master main" +``` ### Nightly @@ -34,18 +100,22 @@ Nightly releases are built most nights from the master branch, and are available at the following URL: `https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc` - -Checksums for the release binary are at: - `https://storage.googleapis.com/gvisor/releases/nightly/latest/runsc.sha512` +You can use this link with the steps described in +[Install latest release](#install-latest). + Specific nightly releases can be found at: `https://storage.googleapis.com/gvisor/releases/nightly/${yyyy-mm-dd}/runsc` Note that a release may not be available for every day. -For `apt` installation, use the `nightly` as the `${DIST}` below. +For `apt` installation, use the `nightly` to configure the repository: + +```bash +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases nightly main" +``` ### Latest release @@ -53,105 +123,47 @@ The latest official release is available at the following URL: `https://storage.googleapis.com/gvisor/releases/release/latest` -For `apt` installation, use the `release` as the `${DIST}` below. - -### Specific release - -A given release release is available at the following URL: - -`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}` - -See the [releases][releases] page for information about specific releases. - -For `apt` installation of a specific release, which may include point updates, -use the date of the release, e.g. `${yyyymmdd}`, as the `${DIST}` below. - -> Note: only newer releases may be available as `apt` repositories. - -### Point release - -A given point release is available at the following URL: - -`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}` - -Note that `apt` installation of a specific point release is not supported. - -## Install from an `apt` repository +You can use this link with the steps described in +[Install latest release](#install-latest). -First, appropriate dependencies must be installed to allow `apt` to install -packages via https: +For `apt` installation, use the `release` to configure the repository: ```bash -sudo apt-get update && \ -sudo apt-get install -y \ - apt-transport-https \ - ca-certificates \ - curl \ - gnupg-agent \ - software-properties-common +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" ``` -Next, the key used to sign archives should be added to your `apt` keychain: - -```bash -curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - -``` +### Specific release -Based on the release type, you will need to substitute `${DIST}` below, using -one of: +A given release release is available at the following URL: -* `master`: For HEAD. -* `nightly`: For nightly releases. -* `release`: For the latest release. -* `${yyyymmdd}`: For a specific releases (see above). +`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}` -The repository for the release you wish to install should be added: +You can use this link with the steps described in +[Install latest release](#install-latest). -```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases ${DIST} main" -``` +See the [releases](https://github.com/google/gvisor/releases) page for +information about specific releases. -For example, to install the latest official release, you can use: +For `apt` installation of a specific release, which may include point updates, +use the date of the release for repository, e.g. `${yyyymmdd}`. ```bash -sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases release main" +sudo add-apt-repository "deb https://storage.googleapis.com/gvisor/releases yyyymmdd main" ``` -Now the runsc package can be installed: - -```bash -sudo apt-get update && sudo apt-get install -y runsc -``` +> Note: only newer releases may be available as `apt` repositories. -If you have Docker installed, it will be automatically configured. +### Point release -## Install directly +A given point release is available at the following URL: -The binary URLs provided above can be used to install directly. For example, the -latest nightly binary can be downloaded, validated, and placed in an appropriate -location by running: +`https://storage.googleapis.com/gvisor/releases/release/${yyyymmdd}.${rc}` -```bash -( - set -e - URL=https://storage.googleapis.com/gvisor/releases/nightly/latest - wget ${URL}/runsc - wget ${URL}/runsc.sha512 - sha512sum -c runsc.sha512 - rm -f runsc.sha512 - sudo mv runsc /usr/local/bin - sudo chown root:root /usr/local/bin/runsc - sudo chmod 0755 /usr/local/bin/runsc -) -``` +You can use this link with the steps described in +[Install latest release](#install-latest). -**It is important to copy this binary to a location that is accessible to all -users, and ensure it is executable by all users**, since `runsc` executes itself -as user `nobody` to avoid unnecessary privileges. The `/usr/local/bin` directory -is a good place to put the `runsc` binary. +Note that `apt` installation of a specific point release is not supported. After installation, try out `runsc` by following the [Docker Quick Start](./quick_start/docker.md) or [OCI Quick Start](./quick_start/oci.md). - -[releases]: https://github.com/google/gvisor/releases diff --git a/g3doc/user_guide/quick_start/docker.md b/g3doc/user_guide/quick_start/docker.md index 6ad594ecc..ee842e453 100644 --- a/g3doc/user_guide/quick_start/docker.md +++ b/g3doc/user_guide/quick_start/docker.md @@ -22,18 +22,6 @@ named "runsc" by default. sudo runsc install ``` -You may also wish to install a runtime entry for debugging. The `runsc install` -command can accept options that will be passed to the runtime when it is invoked -by Docker. - -```bash -sudo runsc install --runtime runsc-debug -- \ - --debug \ - --debug-log=/tmp/runsc-debug.log \ - --strace \ - --log-packets -``` - You must restart the Docker daemon after installing the runtime. Typically this is done via `systemd`: @@ -85,6 +73,21 @@ $ docker run --runtime=runsc -it ubuntu dmesg Note that this is easily replicated by an attacker so applications should never use `dmesg` to verify the runtime in a security sensitive context. +## Options + +You may also wish to install a runtime entry with different options. The `runsc +install` command can accept flags that will be passed to the runtime when it is +invoked by Docker. For example, to install a runtime with debugging enabled, run +the following: + +```bash +sudo runsc install --runtime runsc-debug -- \ + --debug \ + --debug-log=/tmp/runsc-debug.log \ + --strace \ + --log-packets +``` + Next, look at the different options available for gVisor: [platform][platforms], [network][networking], [filesystem][filesystem]. diff --git a/g3doc/user_guide/tutorials/BUILD b/g3doc/user_guide/tutorials/BUILD index 405026a33..f405349b3 100644 --- a/g3doc/user_guide/tutorials/BUILD +++ b/g3doc/user_guide/tutorials/BUILD @@ -15,6 +15,15 @@ doc( ) doc( + name = "docker_compose", + src = "docker-compose.md", + category = "User Guide", + permalink = "/docs/tutorials/docker-compose/", + subcategory = "Tutorials", + weight = "20", +) + +doc( name = "kubernetes", src = "kubernetes.md", category = "User Guide", @@ -24,7 +33,7 @@ doc( ], permalink = "/docs/tutorials/kubernetes/", subcategory = "Tutorials", - weight = "20", + weight = "30", ) doc( @@ -33,5 +42,5 @@ doc( category = "User Guide", permalink = "/docs/tutorials/cni/", subcategory = "Tutorials", - weight = "30", + weight = "40", ) diff --git a/g3doc/user_guide/tutorials/cni.md b/g3doc/user_guide/tutorials/cni.md index ce2fd09a8..a3507c25b 100644 --- a/g3doc/user_guide/tutorials/cni.md +++ b/g3doc/user_guide/tutorials/cni.md @@ -47,7 +47,7 @@ sudo mkdir -p /etc/cni/net.d sudo sh -c 'cat > /etc/cni/net.d/10-bridge.conf << EOF { - "cniVersion": "0.4.0", + "cniVersion": "0.3.1", "name": "mynet", "type": "bridge", "bridge": "cni0", @@ -65,7 +65,7 @@ EOF' sudo sh -c 'cat > /etc/cni/net.d/99-loopback.conf << EOF { - "cniVersion": "0.4.0", + "cniVersion": "0.3.1", "name": "lo", "type": "loopback" } diff --git a/g3doc/user_guide/tutorials/docker-compose.md b/g3doc/user_guide/tutorials/docker-compose.md new file mode 100644 index 000000000..3284231f8 --- /dev/null +++ b/g3doc/user_guide/tutorials/docker-compose.md @@ -0,0 +1,100 @@ +# Wordpress with Docker Compose + +This page shows you how to deploy a sample [WordPress][wordpress] site using +[Docker Compose][docker-compose]. + +### Before you begin + +[Follow these instructions][docker-install] to install runsc with Docker. This +document assumes that Docker and Docker Compose are installed and the runtime +name chosen for gVisor is `runsc`. + +### Configuration + +We'll start by creating the `docker-compose.yaml` file to specify our services. +We will specify two services, a `wordpress` service for the Wordpress Apache +server, and a `db` service for MySQL. We will configure Wordpress to connect to +MySQL via the `db` service host name. + +> **Note:** Docker Compose uses it's own network by default and allows services +> to communicate using their service name. Docker Compose does this by setting +> up a DNS server at IP address 127.0.0.11 and configuring containers to use it +> via [resolv.conf][resolv.conf]. This IP is not addressable inside a gVisor +> sandbox so it's important that we set the DNS IP address to the alternative +> `8.8.8.8` and use a network that allows routing to it. See +> [Networking in Compose][compose-networking] for more details. + +> **Note:** The `runtime` field was removed from services in the 3.x version of +> the API in versions of docker-compose < 1.27.0. You will need to write your +> `docker-compose.yaml` file using the 2.x format or use docker-compose >= +> 1.27.0. See this [issue](https://github.com/docker/compose/issues/6239) for +> more details. + +```yaml +version: '2.3' + +services: + db: + image: mysql:5.7 + volumes: + - db_data:/var/lib/mysql + restart: always + environment: + MYSQL_ROOT_PASSWORD: somewordpress + MYSQL_DATABASE: wordpress + MYSQL_USER: wordpress + MYSQL_PASSWORD: wordpress + # All services must be on the same network to communicate. + network_mode: "bridge" + + wordpress: + depends_on: + - db + # When using the "bridge" network specify links. + links: + - db + image: wordpress:latest + ports: + - "8080:80" + restart: always + environment: + WORDPRESS_DB_HOST: db:3306 + WORDPRESS_DB_USER: wordpress + WORDPRESS_DB_PASSWORD: wordpress + WORDPRESS_DB_NAME: wordpress + # Specify the dns address if needed. + dns: + - 8.8.8.8 + # All services must be on the same network to communicate. + network_mode: "bridge" + # Specify the runtime used by Docker. Must be set up in + # /etc/docker/daemon.json. + runtime: "runsc" + +volumes: + db_data: {} +``` + +Once you have a `docker-compose.yaml` in the current directory you can start the +containers: + +```bash +docker-compose up +``` + +Once the containers have started you can access wordpress at +http://localhost:8080. + +Congrats! You now how a working wordpress site up and running using Docker +Compose. + +### What's next + +Learn how to deploy [WordPress with Kubernetes][wordpress-k8s]. + +[docker-compose]: https://docs.docker.com/compose/ +[docker-install]: ../quick_start/docker.md +[wordpress]: https://wordpress.com/ +[resolv.conf]: https://man7.org/linux/man-pages/man5/resolv.conf.5.html +[wordpress-k8s]: kubernetes.md +[compose-networking]: https://docs.docker.com/compose/networking/ diff --git a/g3doc/user_guide/tutorials/docker.md b/g3doc/user_guide/tutorials/docker.md index 705560038..9ca01da2a 100644 --- a/g3doc/user_guide/tutorials/docker.md +++ b/g3doc/user_guide/tutorials/docker.md @@ -60,9 +60,11 @@ Congratulations! You have just deployed a WordPress site using Docker. ### What's next -[Learn how to deploy WordPress with Kubernetes][wordpress-k8s]. +Learn how to deploy WordPress with [Kubernetes][wordpress-k8s] or +[Docker Compose][wordpress-compose]. [docker]: https://www.docker.com/ -[docker-install]: /docs/user_guide/quick_start/docker/ +[docker-install]: ../quick_start/docker.md [wordpress]: https://wordpress.com/ -[wordpress-k8s]: /docs/tutorials/kubernetes/ +[wordpress-k8s]: kubernetes.md +[wordpress-compose]: docker-compose.md diff --git a/g3doc/user_guide/tutorials/kubernetes.md b/g3doc/user_guide/tutorials/kubernetes.md index d2a94b1b7..1ec6e71e9 100644 --- a/g3doc/user_guide/tutorials/kubernetes.md +++ b/g3doc/user_guide/tutorials/kubernetes.md @@ -23,12 +23,12 @@ gcloud beta container node-pools create sandbox-pool --cluster=${CLUSTER_NAME} - If you prefer to use the console, select your cluster and select the **ADD NODE POOL** button: -![+ ADD NODE POOL](./node-pool-button.png) +![+ ADD NODE POOL](node-pool-button.png) Then select the **Image type** with **Containerd** and select **Enable sandbox with gVisor** option. Select other options as you like: -![+ NODE POOL](./add-node-pool.png) +![+ NODE POOL](add-node-pool.png) ### Check that gVisor is enabled @@ -57,47 +57,149 @@ curl -LO https://k8s.io/examples/application/wordpress/mysql-deployment.yaml Add a **spec.template.spec.runtimeClassName** set to **gvisor** to both files, as shown below: -**wordpress-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata: -name: wordpress labels: app: wordpress spec: ports: - port: 80 selector: app: -wordpress tier: frontend - -## type: LoadBalancer - -apiVersion: v1 kind: PersistentVolumeClaim metadata: name: wp-pv-claim labels: -app: wordpress spec: accessModes: - ReadWriteOnce resources: requests: - -## storage: 20Gi - -apiVersion: apps/v1 kind: Deployment metadata: name: wordpress labels: app: -wordpress spec: selector: matchLabels: app: wordpress tier: frontend strategy: -type: Recreate template: metadata: labels: app: wordpress tier: frontend spec: -runtimeClassName: gvisor # ADD THIS LINE containers: - image: -wordpress:4.8-apache name: wordpress env: - name: WORDPRESS_DB_HOST value: -wordpress-mysql - name: WORDPRESS_DB_PASSWORD valueFrom: secretKeyRef: name: -mysql-pass key: password ports: - containerPort: 80 name: wordpress -volumeMounts: - name: wordpress-persistent-storage mountPath: /var/www/html -volumes: - name: wordpress-persistent-storage persistentVolumeClaim: claimName: -wp-pv-claim ``` - -**mysql-deployment.yaml:** ```yaml apiVersion: v1 kind: Service metadata: name: -wordpress-mysql labels: app: wordpress spec: ports: - port: 3306 selector: app: -wordpress tier: mysql - -## clusterIP: None - -apiVersion: v1 kind: PersistentVolumeClaim metadata: name: mysql-pv-claim -labels: app: wordpress spec: accessModes: - ReadWriteOnce resources: requests: - -## storage: 20Gi +**wordpress-deployment.yaml:** + +```yaml +apiVersion: v1 +kind: Service +metadata: + name: wordpress + labels: + app: wordpress +spec: + ports: + - port: 80 + selector: + app: wordpress + tier: frontend + type: LoadBalancer +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: wp-pv-claim + labels: + app: wordpress +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 20Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: wordpress + labels: + app: wordpress +spec: + selector: + matchLabels: + app: wordpress + tier: frontend + strategy: + type: Recreate + template: + metadata: + labels: + app: wordpress + tier: frontend + spec: + runtimeClassName: gvisor # ADD THIS LINE + containers: + - image: wordpress:4.8-apache + name: wordpress + env: + - name: WORDPRESS_DB_HOST + value: wordpress-mysql + - name: WORDPRESS_DB_PASSWORD + valueFrom: + secretKeyRef: + name: mysql-pass + key: password + ports: + - containerPort: 80 + name: wordpress + volumeMounts: + - name: wordpress-persistent-storage + mountPath: /var/www/html + volumes: + - name: wordpress-persistent-storage + persistentVolumeClaim: + claimName: wp-pv-claim +``` -apiVersion: apps/v1 kind: Deployment metadata: name: wordpress-mysql labels: -app: wordpress spec: selector: matchLabels: app: wordpress tier: mysql strategy: -type: Recreate template: metadata: labels: app: wordpress tier: mysql spec: -runtimeClassName: gvisor # ADD THIS LINE containers: - image: mysql:5.6 name: -mysql env: - name: MYSQL_ROOT_PASSWORD valueFrom: secretKeyRef: name: mysql-pass -key: password ports: - containerPort: 3306 name: mysql volumeMounts: - name: -mysql-persistent-storage mountPath: /var/lib/mysql volumes: - name: -mysql-persistent-storage persistentVolumeClaim: claimName: mysql-pv-claim ``` +**mysql-deployment.yaml:** + +```yaml +apiVersion: v1 +kind: Service +metadata: + name: wordpress-mysql + labels: + app: wordpress +spec: + ports: + - port: 3306 + selector: + app: wordpress + tier: mysql + clusterIP: None +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + name: mysql-pv-claim + labels: + app: wordpress +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: 20Gi +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: wordpress-mysql + labels: + app: wordpress +spec: + selector: + matchLabels: + app: wordpress + tier: mysql + strategy: + type: Recreate + template: + metadata: + labels: + app: wordpress + tier: mysql + spec: + runtimeClassName: gvisor # ADD THIS LINE + containers: + - image: mysql:5.6 + name: mysql + env: + - name: MYSQL_ROOT_PASSWORD + valueFrom: + secretKeyRef: + name: mysql-pass + key: password + ports: + - containerPort: 3306 + name: mysql + volumeMounts: + - name: mysql-persistent-storage + mountPath: /var/lib/mysql + volumes: + - name: mysql-persistent-storage + persistentVolumeClaim: + claimName: mysql-pv-claim +``` Note that apart from `runtimeClassName: gvisor`, nothing else about the Deployment has is changed. diff --git a/images/Makefile b/images/Makefile index 278dec02f..d4b6524ba 100644 --- a/images/Makefile +++ b/images/Makefile @@ -59,9 +59,9 @@ local_image = $(LOCAL_IMAGE_PREFIX)/$(subst _,/,$(1)) # we need to explicitly repull the base layer in order to ensure that the # architecture is correct. Note that we use the term "rebuild" here to avoid # conflicting with the bazel "build" terminology, which is used elsewhere. -rebuild-%: FROM=$(shell grep FROM $(call path,$*)/Dockerfile } cut -d' ' -f2) +rebuild-%: FROM=$(shell grep FROM $(call path,$*)/Dockerfile | cut -d' ' -f2) rebuild-%: register-cross - $(foreach IMAGE,$(FROM),docker $(DOCKER_PLATFORM_ARGS) $(IMAGE); &&) true + $(foreach IMAGE,$(FROM),docker pull $(DOCKER_PLATFORM_ARGS) $(IMAGE) &&) \ T=$$(mktemp -d) && cp -a $(call path,$*)/* $$T && \ docker build $(DOCKER_PLATFORM_ARGS) -t $(call remote_image,$*) $$T && \ rm -rf $$T @@ -73,10 +73,10 @@ pull-%: docker pull $(DOCKER_PLATFORM_ARGS) $(call remote_image,$*) # load will either pull the "remote" or build it locally. This is the preferred -# entrypoint, as it should never file. The local tag should always be set after +# entrypoint, as it should never fail. The local tag should always be set after # this returns (either by the pull or the build). load-%: - docker inspect $(call remote_image,$*) >/dev/null 2>&1 || $(MAKE) pull-$* || $(MAKE) rebuild-$* + $(MAKE) pull-$* || $(MAKE) rebuild-$* docker tag $(call remote_image,$*) $(call local_image,$*) # push pushes the remote image, after either pulling (to validate that the tag diff --git a/images/benchmarks/httpd/Dockerfile b/images/benchmarks/httpd/Dockerfile index b72406012..e95538a40 100644 --- a/images/benchmarks/httpd/Dockerfile +++ b/images/benchmarks/httpd/Dockerfile @@ -8,7 +8,7 @@ RUN set -x \ # Generate a bunch of relevant files. RUN mkdir -p /local && \ - for size in 1 10 100 1000 1024 10240; do \ + for size in 1 10 100 1024 10240; do \ dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ done diff --git a/images/benchmarks/nginx/Dockerfile b/images/benchmarks/nginx/Dockerfile index b64eb52ae..2444d04b1 100644 --- a/images/benchmarks/nginx/Dockerfile +++ b/images/benchmarks/nginx/Dockerfile @@ -1 +1,11 @@ FROM nginx:1.15.10 + +# Generate a bunch of relevant files. +RUN mkdir -p /local && \ + for size in 1 10 100 1024 10240; do \ + dd if=/dev/zero of=/local/latin${size}k.txt count=${size} bs=1024; \ + done + +RUN touch /local/index.html + +COPY ./nginx.conf /etc/nginx/nginx.conf diff --git a/images/benchmarks/nginx/nginx.conf b/images/benchmarks/nginx/nginx.conf new file mode 100644 index 000000000..2c43c0cda --- /dev/null +++ b/images/benchmarks/nginx/nginx.conf @@ -0,0 +1,19 @@ +user nginx; +worker_processes 1; +daemon off; + +error_log /var/log/nginx/error.log warn; +pid /var/run/nginx.pid; + +events { + worker_connections 1024; +} + + +http { + server { + location / { + root /tmp/html; + } + } +} diff --git a/images/jekyll/Dockerfile b/images/jekyll/Dockerfile index ba039ba15..ae19f3bfc 100644 --- a/images/jekyll/Dockerfile +++ b/images/jekyll/Dockerfile @@ -1,5 +1,6 @@ FROM jekyll/jekyll:4.0.0 USER root + RUN gem install \ html-proofer:3.10.2 \ nokogiri:1.10.1 \ @@ -10,5 +11,9 @@ RUN gem install \ jekyll-relative-links:0.6.1 \ jekyll-feed:0.13.0 \ jekyll-sitemap:1.4.0 + +# checks.rb is used with html-proofer for presubmit checks. COPY checks.rb /checks.rb -CMD ["/usr/gem/gems/jekyll-4.0.0/exe/jekyll", "build", "-t", "-s", "/input", "-d", "/output"] + +COPY build.sh /build.sh +CMD ["/build.sh"] diff --git a/images/jekyll/build.sh b/images/jekyll/build.sh new file mode 100755 index 000000000..010972ea6 --- /dev/null +++ b/images/jekyll/build.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# Copyright 2020 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -euxo pipefail + +# Generate the syntax highlighting css file. +/usr/gem/bin/rougify style github >/input/_sass/syntax.css +# Build website including pages irrespective of date. +/usr/gem/bin/jekyll build --future -t -s /input -d /output diff --git a/images/packetdrill/Dockerfile b/images/packetdrill/Dockerfile index 01296dbaf..b4cd73006 100644 --- a/images/packetdrill/Dockerfile +++ b/images/packetdrill/Dockerfile @@ -1,8 +1,8 @@ FROM ubuntu:bionic RUN apt-get update && apt-get install -y net-tools git iptables iputils-ping \ netcat tcpdump jq tar bison flex make +# Pick up updated git. RUN hash -r RUN git clone --depth 1 --branch packetdrill-v2.0 \ https://github.com/google/packetdrill.git RUN cd packetdrill/gtests/net/packetdrill && ./configure && make -CMD /bin/bash diff --git a/images/packetimpact/Dockerfile b/images/packetimpact/Dockerfile index 87aa99ef2..906d5cdd6 100644 --- a/images/packetimpact/Dockerfile +++ b/images/packetimpact/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:bionic +FROM ubuntu:focal RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ # iptables to disable OS native packet processing. iptables \ @@ -11,6 +11,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ # tshark to log verbose packet sniffing. tshark \ # killall for cleanup. - psmisc -RUN hash -r -CMD /bin/bash + psmisc \ + # qemu-system-x86 to emulate fuchsia. + qemu-system-x86 \ + # sha1sum to generate entropy. + libdigest-sha-perl diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index b5c5cc20b..cdcaa8c73 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -74,9 +74,9 @@ go_library( "//pkg/abi", "//pkg/binary", "//pkg/bits", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/usermem", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/abi/linux/aio.go b/pkg/abi/linux/aio.go index 86ee3f8b5..5fc099892 100644 --- a/pkg/abi/linux/aio.go +++ b/pkg/abi/linux/aio.go @@ -42,6 +42,8 @@ const ( // // The priority field is currently ignored in the implementation below. Also // note that the IOCB_FLAG_RESFD feature is not supported. +// +// +marshal type IOCallback struct { Data uint64 Key uint32 @@ -64,6 +66,7 @@ type IOCallback struct { // IOEvent describes an I/O result. // +// +marshal // +stateify savable type IOEvent struct { Data uint64 diff --git a/pkg/abi/linux/bpf.go b/pkg/abi/linux/bpf.go index aa3d3ce70..9422fcf69 100644 --- a/pkg/abi/linux/bpf.go +++ b/pkg/abi/linux/bpf.go @@ -16,6 +16,7 @@ package linux // BPFInstruction is a raw BPF virtual machine instruction. // +// +marshal slice:BPFInstructionSlice // +stateify savable type BPFInstruction struct { // OpCode is the operation to execute. diff --git a/pkg/abi/linux/capability.go b/pkg/abi/linux/capability.go index 965f74663..afd16cc27 100644 --- a/pkg/abi/linux/capability.go +++ b/pkg/abi/linux/capability.go @@ -177,12 +177,16 @@ const ( ) // CapUserHeader is equivalent to Linux's cap_user_header_t. +// +// +marshal type CapUserHeader struct { Version uint32 Pid int32 } // CapUserData is equivalent to Linux's cap_user_data_t. +// +// +marshal slice:CapUserDataSlice type CapUserData struct { Effective uint32 Permitted uint32 diff --git a/pkg/abi/linux/dev.go b/pkg/abi/linux/dev.go index 192e2093b..7771650b3 100644 --- a/pkg/abi/linux/dev.go +++ b/pkg/abi/linux/dev.go @@ -54,9 +54,9 @@ const ( // Unix98 PTY masters. UNIX98_PTY_MASTER_MAJOR = 128 - // UNIX98_PTY_SLAVE_MAJOR is the initial major device number for - // Unix98 PTY slaves. - UNIX98_PTY_SLAVE_MAJOR = 136 + // UNIX98_PTY_REPLICA_MAJOR is the initial major device number for + // Unix98 PTY replicas. + UNIX98_PTY_REPLICA_MAJOR = 136 ) // Minor device numbers for TTYAUX_MAJOR. diff --git a/pkg/abi/linux/fcntl.go b/pkg/abi/linux/fcntl.go index 9242e80a5..cc3571fad 100644 --- a/pkg/abi/linux/fcntl.go +++ b/pkg/abi/linux/fcntl.go @@ -45,6 +45,8 @@ const ( ) // Flock is the lock structure for F_SETLK. +// +// +marshal type Flock struct { Type int16 Whence int16 @@ -63,6 +65,8 @@ const ( ) // FOwnerEx is the owner structure for F_SETOWN_EX and F_GETOWN_EX. +// +// +marshal type FOwnerEx struct { Type int32 PID int32 diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go index 158d2db5b..0d921ed6f 100644 --- a/pkg/abi/linux/fs.go +++ b/pkg/abi/linux/fs.go @@ -29,6 +29,7 @@ const ( SYSFS_MAGIC = 0x62656572 TMPFS_MAGIC = 0x01021994 V9FS_MAGIC = 0x01021997 + FUSE_SUPER_MAGIC = 0x65735546 ) // Filesystem path limits, from uapi/linux/limits.h. @@ -44,17 +45,18 @@ type Statfs struct { // Type is one of the filesystem magic values, defined above. Type uint64 - // BlockSize is the data block size. + // BlockSize is the optimal transfer block size in bytes. BlockSize int64 - // Blocks is the number of data blocks in use. + // Blocks is the maximum number of data blocks the filesystem may store, in + // units of BlockSize. Blocks uint64 - // BlocksFree is the number of free blocks. + // BlocksFree is the number of free data blocks, in units of BlockSize. BlocksFree uint64 - // BlocksAvailable is the number of blocks free for use by - // unprivileged users. + // BlocksAvailable is the number of data blocks free for use by + // unprivileged users, in units of BlockSize. BlocksAvailable uint64 // Files is the number of used file nodes on the filesystem. diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index 7e30483ee..d91c97a64 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -14,12 +14,20 @@ package linux +import ( + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" +) + // +marshal type FUSEOpcode uint32 // +marshal type FUSEOpID uint64 +// FUSE_ROOT_ID is the id of root inode. +const FUSE_ROOT_ID = 1 + // Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h. const ( FUSE_LOOKUP FUSEOpcode = 1 @@ -116,61 +124,28 @@ type FUSEHeaderOut struct { Unique FUSEOpID } -// FUSEWriteIn is the header written by a daemon when it makes a -// write request to the FUSE filesystem. -// -// +marshal -type FUSEWriteIn struct { - // Fh specifies the file handle that is being written to. - Fh uint64 - - // Offset is the offset of the write. - Offset uint64 - - // Size is the size of data being written. - Size uint32 - - // WriteFlags is the flags used during the write. - WriteFlags uint32 - - // LockOwner is the ID of the lock owner. - LockOwner uint64 - - // Flags is the flags for the request. - Flags uint32 - - _ uint32 -} - // FUSE_INIT flags, consistent with the ones in include/uapi/linux/fuse.h. +// Our taget version is 7.23 but we have few implemented in advance. const ( - FUSE_ASYNC_READ = 1 << 0 - FUSE_POSIX_LOCKS = 1 << 1 - FUSE_FILE_OPS = 1 << 2 - FUSE_ATOMIC_O_TRUNC = 1 << 3 - FUSE_EXPORT_SUPPORT = 1 << 4 - FUSE_BIG_WRITES = 1 << 5 - FUSE_DONT_MASK = 1 << 6 - FUSE_SPLICE_WRITE = 1 << 7 - FUSE_SPLICE_MOVE = 1 << 8 - FUSE_SPLICE_READ = 1 << 9 - FUSE_FLOCK_LOCKS = 1 << 10 - FUSE_HAS_IOCTL_DIR = 1 << 11 - FUSE_AUTO_INVAL_DATA = 1 << 12 - FUSE_DO_READDIRPLUS = 1 << 13 - FUSE_READDIRPLUS_AUTO = 1 << 14 - FUSE_ASYNC_DIO = 1 << 15 - FUSE_WRITEBACK_CACHE = 1 << 16 - FUSE_NO_OPEN_SUPPORT = 1 << 17 - FUSE_PARALLEL_DIROPS = 1 << 18 - FUSE_HANDLE_KILLPRIV = 1 << 19 - FUSE_POSIX_ACL = 1 << 20 - FUSE_ABORT_ERROR = 1 << 21 - FUSE_MAX_PAGES = 1 << 22 - FUSE_CACHE_SYMLINKS = 1 << 23 - FUSE_NO_OPENDIR_SUPPORT = 1 << 24 - FUSE_EXPLICIT_INVAL_DATA = 1 << 25 - FUSE_MAP_ALIGNMENT = 1 << 26 + FUSE_ASYNC_READ = 1 << 0 + FUSE_POSIX_LOCKS = 1 << 1 + FUSE_FILE_OPS = 1 << 2 + FUSE_ATOMIC_O_TRUNC = 1 << 3 + FUSE_EXPORT_SUPPORT = 1 << 4 + FUSE_BIG_WRITES = 1 << 5 + FUSE_DONT_MASK = 1 << 6 + FUSE_SPLICE_WRITE = 1 << 7 + FUSE_SPLICE_MOVE = 1 << 8 + FUSE_SPLICE_READ = 1 << 9 + FUSE_FLOCK_LOCKS = 1 << 10 + FUSE_HAS_IOCTL_DIR = 1 << 11 + FUSE_AUTO_INVAL_DATA = 1 << 12 + FUSE_DO_READDIRPLUS = 1 << 13 + FUSE_READDIRPLUS_AUTO = 1 << 14 + FUSE_ASYNC_DIO = 1 << 15 + FUSE_WRITEBACK_CACHE = 1 << 16 + FUSE_NO_OPEN_SUPPORT = 1 << 17 + FUSE_MAX_PAGES = 1 << 22 // From FUSE 7.28 ) // currently supported FUSE protocol version numbers. @@ -179,6 +154,13 @@ const ( FUSE_KERNEL_MINOR_VERSION = 31 ) +// Constants relevant to FUSE operations. +const ( + FUSE_NAME_MAX = 1024 + FUSE_PAGE_SIZE = 4096 + FUSE_DIRENT_ALIGN = 8 +) + // FUSEInitIn is the request sent by the kernel to the daemon, // to negotiate the version and flags. // @@ -199,7 +181,7 @@ type FUSEInitIn struct { } // FUSEInitOut is the reply sent by the daemon to the kernel -// for FUSEInitIn. +// for FUSEInitIn. We target FUSE 7.23; this struct supports 7.28. // // +marshal type FUSEInitOut struct { @@ -240,13 +222,16 @@ type FUSEInitOut struct { // if the value from daemon is too large. MaxPages uint16 - // MapAlignment is an unknown field and not used by this package at this moment. - // Use as a placeholder to be consistent with the FUSE protocol. - MapAlignment uint16 + _ uint16 _ [8]uint32 } +// FUSE_GETATTR_FH is currently the only flag of FUSEGetAttrIn.GetAttrFlags. +// If it is set, the file handle (FUSEGetAttrIn.Fh) is used to indicate the +// object instead of the node id attribute in the request header. +const FUSE_GETATTR_FH = (1 << 0) + // FUSEGetAttrIn is the request sent by the kernel to the daemon, // to get the attribute of a inode. // @@ -267,22 +252,52 @@ type FUSEGetAttrIn struct { // // +marshal type FUSEAttr struct { - Ino uint64 - Size uint64 - Blocks uint64 - Atime uint64 - Mtime uint64 - Ctime uint64 + // Ino is the inode number of this file. + Ino uint64 + + // Size is the size of this file. + Size uint64 + + // Blocks is the number of the 512B blocks allocated by this file. + Blocks uint64 + + // Atime is the time of last access. + Atime uint64 + + // Mtime is the time of last modification. + Mtime uint64 + + // Ctime is the time of last status change. + Ctime uint64 + + // AtimeNsec is the nano second part of Atime. AtimeNsec uint32 + + // MtimeNsec is the nano second part of Mtime. MtimeNsec uint32 + + // CtimeNsec is the nano second part of Ctime. CtimeNsec uint32 - Mode uint32 - Nlink uint32 - UID uint32 - GID uint32 - Rdev uint32 - BlkSize uint32 - _ uint32 + + // Mode contains the file type and mode. + Mode uint32 + + // Nlink is the number of the hard links. + Nlink uint32 + + // UID is user ID of the owner. + UID uint32 + + // GID is group ID of the owner. + GID uint32 + + // Rdev is the device ID if this is a special file. + Rdev uint32 + + // BlkSize is the block size for filesystem I/O. + BlkSize uint32 + + _ uint32 } // FUSEGetAttrOut is the reply sent by the daemon to the kernel @@ -301,3 +316,558 @@ type FUSEGetAttrOut struct { // Attr contains the metadata returned from the FUSE server Attr FUSEAttr } + +// FUSEEntryOut is the reply sent by the daemon to the kernel +// for FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and +// FUSE_LOOKUP. +// +// +marshal +type FUSEEntryOut struct { + // NodeID is the ID for current inode. + NodeID uint64 + + // Generation is the generation number of inode. + // Used to identify an inode that have different ID at different time. + Generation uint64 + + // EntryValid indicates timeout for an entry. + EntryValid uint64 + + // AttrValid indicates timeout for an entry's attributes. + AttrValid uint64 + + // EntryValidNsec indicates timeout for an entry in nanosecond. + EntryValidNSec uint32 + + // AttrValidNsec indicates timeout for an entry's attributes in nanosecond. + AttrValidNSec uint32 + + // Attr contains the attributes of an entry. + Attr FUSEAttr +} + +// FUSELookupIn is the request sent by the kernel to the daemon +// to look up a file name. +// +// Dynamically-sized objects cannot be marshalled. +type FUSELookupIn struct { + marshal.StubMarshallable + + // Name is a file name to be looked up. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSELookupIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSELookupIn. +// 1 extra byte for null-terminated string. +func (r *FUSELookupIn) SizeBytes() int { + return len(r.Name) + 1 +} + +// MAX_NON_LFS indicates the maximum offset without large file support. +const MAX_NON_LFS = ((1 << 31) - 1) + +// flags returned by OPEN request. +const ( + // FOPEN_DIRECT_IO indicates bypassing page cache for this opened file. + FOPEN_DIRECT_IO = 1 << 0 + // FOPEN_KEEP_CACHE avoids invalidate of data cache on open. + FOPEN_KEEP_CACHE = 1 << 1 + // FOPEN_NONSEEKABLE indicates the file cannot be seeked. + FOPEN_NONSEEKABLE = 1 << 2 +) + +// FUSEOpenIn is the request sent by the kernel to the daemon, +// to negotiate flags and get file handle. +// +// +marshal +type FUSEOpenIn struct { + // Flags of this open request. + Flags uint32 + + _ uint32 +} + +// FUSEOpenOut is the reply sent by the daemon to the kernel +// for FUSEOpenIn. +// +// +marshal +type FUSEOpenOut struct { + // Fh is the file handler for opened file. + Fh uint64 + + // OpenFlag for the opened file. + OpenFlag uint32 + + _ uint32 +} + +// FUSE_READ flags, consistent with the ones in include/uapi/linux/fuse.h. +const ( + FUSE_READ_LOCKOWNER = 1 << 1 +) + +// FUSEReadIn is the request sent by the kernel to the daemon +// for FUSE_READ. +// +// +marshal +type FUSEReadIn struct { + // Fh is the file handle in userspace. + Fh uint64 + + // Offset is the read offset. + Offset uint64 + + // Size is the number of bytes to read. + Size uint32 + + // ReadFlags for this FUSE_READ request. + // Currently only contains FUSE_READ_LOCKOWNER. + ReadFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 + + // Flags for the underlying file. + Flags uint32 + + _ uint32 +} + +// FUSEWriteIn is the first part of the payload of the +// request sent by the kernel to the daemon +// for FUSE_WRITE (struct for FUSE version >= 7.9). +// +// The second part of the payload is the +// binary bytes of the data to be written. +// +// +marshal +type FUSEWriteIn struct { + // Fh is the file handle in userspace. + Fh uint64 + + // Offset is the write offset. + Offset uint64 + + // Size is the number of bytes to write. + Size uint32 + + // ReadFlags for this FUSE_WRITE request. + WriteFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 + + // Flags for the underlying file. + Flags uint32 + + _ uint32 +} + +// FUSEWriteOut is the payload of the reply sent by the daemon to the kernel +// for a FUSE_WRITE request. +// +// +marshal +type FUSEWriteOut struct { + // Size is the number of bytes written. + Size uint32 + + _ uint32 +} + +// FUSEReleaseIn is the request sent by the kernel to the daemon +// when there is no more reference to a file. +// +// +marshal +type FUSEReleaseIn struct { + // Fh is the file handler for the file to be released. + Fh uint64 + + // Flags of the file. + Flags uint32 + + // ReleaseFlags of this release request. + ReleaseFlags uint32 + + // LockOwner is the id of the lock owner if there is one. + LockOwner uint64 +} + +// FUSECreateMeta contains all the static fields of FUSECreateIn, +// which is used for FUSE_CREATE. +// +// +marshal +type FUSECreateMeta struct { + // Flags of the creating file. + Flags uint32 + + // Mode is the mode of the creating file. + Mode uint32 + + // Umask is the current file mode creation mask. + Umask uint32 + _ uint32 +} + +// FUSECreateIn contains all the arguments sent by the kernel to the daemon, to +// atomically create and open a new regular file. +// +// Dynamically-sized objects cannot be marshalled. +type FUSECreateIn struct { + marshal.StubMarshallable + + // CreateMeta contains mode, rdev and umash field for FUSE_MKNODS. + CreateMeta FUSECreateMeta + + // Name is the name of the node to create. + Name string +} + +// MarshalBytes serializes r.CreateMeta and r.Name to the dst buffer. +func (r *FUSECreateIn) MarshalBytes(buf []byte) { + r.CreateMeta.MarshalBytes(buf[:r.CreateMeta.SizeBytes()]) + copy(buf[r.CreateMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSECreateIn. +// 1 extra byte for null-terminated string. +func (r *FUSECreateIn) SizeBytes() int { + return r.CreateMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSEMknodMeta contains all the static fields of FUSEMknodIn, +// which is used for FUSE_MKNOD. +// +// +marshal +type FUSEMknodMeta struct { + // Mode of the inode to create. + Mode uint32 + + // Rdev encodes device major and minor information. + Rdev uint32 + + // Umask is the current file mode creation mask. + Umask uint32 + + _ uint32 +} + +// FUSEMknodIn contains all the arguments sent by the kernel +// to the daemon, to create a new file node. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEMknodIn struct { + marshal.StubMarshallable + + // MknodMeta contains mode, rdev and umash field for FUSE_MKNODS. + MknodMeta FUSEMknodMeta + + // Name is the name of the node to create. + Name string +} + +// MarshalBytes serializes r.MknodMeta and r.Name to the dst buffer. +func (r *FUSEMknodIn) MarshalBytes(buf []byte) { + r.MknodMeta.MarshalBytes(buf[:r.MknodMeta.SizeBytes()]) + copy(buf[r.MknodMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEMknodIn. +// 1 extra byte for null-terminated string. +func (r *FUSEMknodIn) SizeBytes() int { + return r.MknodMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSESymLinkIn is the request sent by the kernel to the daemon, +// to create a symbolic link. +// +// Dynamically-sized objects cannot be marshalled. +type FUSESymLinkIn struct { + marshal.StubMarshallable + + // Name of symlink to create. + Name string + + // Target of the symlink. + Target string +} + +// MarshalBytes serializes r.Name and r.Target to the dst buffer. +// Left null-termination at end of r.Name and r.Target. +func (r *FUSESymLinkIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) + copy(buf[len(r.Name)+1:], r.Target) +} + +// SizeBytes is the size of the memory representation of FUSESymLinkIn. +// 2 extra bytes for null-terminated string. +func (r *FUSESymLinkIn) SizeBytes() int { + return len(r.Name) + len(r.Target) + 2 +} + +// FUSEEmptyIn is used by operations without request body. +type FUSEEmptyIn struct{ marshal.StubMarshallable } + +// MarshalBytes do nothing for marshal. +func (r *FUSEEmptyIn) MarshalBytes(buf []byte) {} + +// SizeBytes is 0 for empty request. +func (r *FUSEEmptyIn) SizeBytes() int { + return 0 +} + +// FUSEMkdirMeta contains all the static fields of FUSEMkdirIn, +// which is used for FUSE_MKDIR. +// +// +marshal +type FUSEMkdirMeta struct { + // Mode of the directory of create. + Mode uint32 + + // Umask is the user file creation mask. + Umask uint32 +} + +// FUSEMkdirIn contains all the arguments sent by the kernel +// to the daemon, to create a new directory. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEMkdirIn struct { + marshal.StubMarshallable + + // MkdirMeta contains Mode and Umask of the directory to create. + MkdirMeta FUSEMkdirMeta + + // Name of the directory to create. + Name string +} + +// MarshalBytes serializes r.MkdirMeta and r.Name to the dst buffer. +func (r *FUSEMkdirIn) MarshalBytes(buf []byte) { + r.MkdirMeta.MarshalBytes(buf[:r.MkdirMeta.SizeBytes()]) + copy(buf[r.MkdirMeta.SizeBytes():], r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEMkdirIn. +// 1 extra byte for null-terminated Name string. +func (r *FUSEMkdirIn) SizeBytes() int { + return r.MkdirMeta.SizeBytes() + len(r.Name) + 1 +} + +// FUSERmDirIn is the request sent by the kernel to the daemon +// when trying to remove a directory. +// +// Dynamically-sized objects cannot be marshalled. +type FUSERmDirIn struct { + marshal.StubMarshallable + + // Name is a directory name to be removed. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSERmDirIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSERmDirIn. +func (r *FUSERmDirIn) SizeBytes() int { + return len(r.Name) + 1 +} + +// FUSEDirents is a list of Dirents received from the FUSE daemon server. +// It is used for FUSE_READDIR. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEDirents struct { + marshal.StubMarshallable + + Dirents []*FUSEDirent +} + +// FUSEDirent is a Dirent received from the FUSE daemon server. +// It is used for FUSE_READDIR. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEDirent struct { + marshal.StubMarshallable + + // Meta contains all the static fields of FUSEDirent. + Meta FUSEDirentMeta + + // Name is the filename of the dirent. + Name string +} + +// FUSEDirentMeta contains all the static fields of FUSEDirent. +// It is used for FUSE_READDIR. +// +// +marshal +type FUSEDirentMeta struct { + // Inode of the dirent. + Ino uint64 + + // Offset of the dirent. + Off uint64 + + // NameLen is the length of the dirent name. + NameLen uint32 + + // Type of the dirent. + Type uint32 +} + +// SizeBytes is the size of the memory representation of FUSEDirents. +func (r *FUSEDirents) SizeBytes() int { + var sizeBytes int + for _, dirent := range r.Dirents { + sizeBytes += dirent.SizeBytes() + } + + return sizeBytes +} + +// UnmarshalBytes deserializes FUSEDirents from the src buffer. +func (r *FUSEDirents) UnmarshalBytes(src []byte) { + for { + if len(src) <= (*FUSEDirentMeta)(nil).SizeBytes() { + break + } + + // Its unclear how many dirents there are in src. Each dirent is dynamically + // sized and so we can't make assumptions about how many dirents we can allocate. + if r.Dirents == nil { + r.Dirents = make([]*FUSEDirent, 0) + } + + // We have to allocate a struct for each dirent - there must be a better way + // to do this. Linux allocates 1 page to store all the dirents and then + // simply reads them from the page. + var dirent FUSEDirent + dirent.UnmarshalBytes(src) + r.Dirents = append(r.Dirents, &dirent) + + src = src[dirent.SizeBytes():] + } +} + +// SizeBytes is the size of the memory representation of FUSEDirent. +func (r *FUSEDirent) SizeBytes() int { + dataSize := r.Meta.SizeBytes() + len(r.Name) + + // Each Dirent must be padded such that its size is a multiple + // of FUSE_DIRENT_ALIGN. Similar to the fuse dirent alignment + // in linux/fuse.h. + return (dataSize + (FUSE_DIRENT_ALIGN - 1)) & ^(FUSE_DIRENT_ALIGN - 1) +} + +// UnmarshalBytes deserializes FUSEDirent from the src buffer. +func (r *FUSEDirent) UnmarshalBytes(src []byte) { + r.Meta.UnmarshalBytes(src) + src = src[r.Meta.SizeBytes():] + + if r.Meta.NameLen > FUSE_NAME_MAX { + // The name is too long and therefore invalid. We don't + // need to unmarshal the name since it'll be thrown away. + return + } + + buf := make([]byte, r.Meta.NameLen) + name := primitive.ByteSlice(buf) + name.UnmarshalBytes(src[:r.Meta.NameLen]) + r.Name = string(name) +} + +// FATTR_* consts are the attribute flags defined in include/uapi/linux/fuse.h. +// These should be or-ed together for setattr to know what has been changed. +const ( + FATTR_MODE = (1 << 0) + FATTR_UID = (1 << 1) + FATTR_GID = (1 << 2) + FATTR_SIZE = (1 << 3) + FATTR_ATIME = (1 << 4) + FATTR_MTIME = (1 << 5) + FATTR_FH = (1 << 6) + FATTR_ATIME_NOW = (1 << 7) + FATTR_MTIME_NOW = (1 << 8) + FATTR_LOCKOWNER = (1 << 9) + FATTR_CTIME = (1 << 10) +) + +// FUSESetAttrIn is the request sent by the kernel to the daemon, +// to set the attribute(s) of a file. +// +// +marshal +type FUSESetAttrIn struct { + // Valid indicates which attributes are modified by this request. + Valid uint32 + + _ uint32 + + // Fh is used to identify the file if FATTR_FH is set in Valid. + Fh uint64 + + // Size is the size that the request wants to change to. + Size uint64 + + // LockOwner is the owner of the lock that the request wants to change to. + LockOwner uint64 + + // Atime is the access time that the request wants to change to. + Atime uint64 + + // Mtime is the modification time that the request wants to change to. + Mtime uint64 + + // Ctime is the status change time that the request wants to change to. + Ctime uint64 + + // AtimeNsec is the nano second part of Atime. + AtimeNsec uint32 + + // MtimeNsec is the nano second part of Mtime. + MtimeNsec uint32 + + // CtimeNsec is the nano second part of Ctime. + CtimeNsec uint32 + + // Mode is the file mode that the request wants to change to. + Mode uint32 + + _ uint32 + + // UID is the user ID of the owner that the request wants to change to. + UID uint32 + + // GID is the group ID of the owner that the request wants to change to. + GID uint32 + + _ uint32 +} + +// FUSEUnlinkIn is the request sent by the kernel to the daemon +// when trying to unlink a node. +// +// Dynamically-sized objects cannot be marshalled. +type FUSEUnlinkIn struct { + marshal.StubMarshallable + + // Name of the node to unlink. + Name string +} + +// MarshalBytes serializes r.name to the dst buffer, which should +// have size len(r.Name) + 1 and last byte set to 0. +func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) { + copy(buf, r.Name) +} + +// SizeBytes is the size of the memory representation of FUSEUnlinkIn. +// 1 extra byte for null-terminated Name string. +func (r *FUSEUnlinkIn) SizeBytes() int { + return len(r.Name) + 1 +} diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 2c5e56ae5..3356a2b4a 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -113,7 +113,39 @@ const ( _IOC_DIRSHIFT = _IOC_SIZESHIFT + _IOC_SIZEBITS ) +// Constants from uapi/linux/fs.h. +const ( + FS_IOC_GETFLAGS = 2147771905 + FS_VERITY_FL = 1048576 +) + +// Constants from uapi/linux/fsverity.h. +const ( + FS_IOC_ENABLE_VERITY = 1082156677 +) + // IOC outputs the result of _IOC macro in asm-generic/ioctl.h. func IOC(dir, typ, nr, size uint32) uint32 { return uint32(dir)<<_IOC_DIRSHIFT | typ<<_IOC_TYPESHIFT | nr<<_IOC_NRSHIFT | size<<_IOC_SIZESHIFT } + +// Kcov ioctls from kernel/kcov.h. +var ( + KCOV_INIT_TRACE = IOC(_IOC_READ, 'c', 1, 8) + KCOV_ENABLE = IOC(_IOC_NONE, 'c', 100, 0) + KCOV_DISABLE = IOC(_IOC_NONE, 'c', 101, 0) +) + +// Kcov trace types from kernel/kcov.h. +const ( + KCOV_TRACE_PC = 0 + KCOV_TRACE_CMP = 1 +) + +// Kcov state constants from kernel/kcov.h. +const ( + KCOV_MODE_DISABLED = 0 + KCOV_MODE_INIT = 1 + KCOV_MODE_TRACE_PC = 2 + KCOV_MODE_TRACE_CMP = 3 +) diff --git a/pkg/abi/linux/ipc.go b/pkg/abi/linux/ipc.go index 22acd2d43..c6e65df62 100644 --- a/pkg/abi/linux/ipc.go +++ b/pkg/abi/linux/ipc.go @@ -37,6 +37,8 @@ const IPC_PRIVATE = 0 // features like 32-bit UIDs. // IPCPerm is equivalent to struct ipc64_perm. +// +// +marshal type IPCPerm struct { Key uint32 UID uint32 diff --git a/pkg/abi/linux/linux.go b/pkg/abi/linux/linux.go index 281acdbde..3b4abece1 100644 --- a/pkg/abi/linux/linux.go +++ b/pkg/abi/linux/linux.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package linux contains the constants and types needed to interface with a Linux kernel. +// Package linux contains the constants and types needed to interface with a +// Linux kernel. package linux // NumSoftIRQ is the number of software IRQs, exposed via /proc/stat. @@ -21,6 +22,8 @@ package linux const NumSoftIRQ = 10 // Sysinfo is the structure provided by sysinfo on linux versions > 2.3.48. +// +// +marshal type Sysinfo struct { Uptime int64 Loads [3]uint64 @@ -34,6 +37,6 @@ type Sysinfo struct { _ [6]byte // Pad Procs to 64bits. TotalHigh uint64 FreeHigh uint64 - Unit uint32 - /* The _f field in the glibc version of Sysinfo has size 0 on AMD64 */ + Unit uint32 `marshal:"unaligned"` // Struct ends mid-64-bit-word. + // The _f field in the glibc version of Sysinfo has size 0 on AMD64. } diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 91e35366f..1c5b34711 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -17,9 +17,9 @@ package linux import ( "io" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // This file contains structures required to support netfilter, specifically @@ -450,9 +450,9 @@ func (ke *KernelIPTGetEntries) UnmarshalUnsafe(src []byte) { } // CopyIn implements marshal.Marshallable.CopyIn. -func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. - length, err := task.CopyInBytes(addr, buf) // escapes: okay. +func (ke *KernelIPTGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. // Unmarshal unconditionally. If we had a short copy-in, this results in a // partially unmarshalled struct. ke.UnmarshalBytes(buf) // escapes: fallback. @@ -460,21 +460,21 @@ func (ke *KernelIPTGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int } // CopyOut implements marshal.Marshallable.CopyOut. -func (ke *KernelIPTGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { +func (ke *KernelIPTGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall // back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)) } // CopyOutN implements marshal.Marshallable.CopyOutN. -func (ke *KernelIPTGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { +func (ke *KernelIPTGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { // Type KernelIPTGetEntries doesn't have a packed layout in memory, fall // back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) } -func (ke *KernelIPTGetEntries) marshalAll(task marshal.Task) []byte { - buf := task.CopyScratchBuffer(ke.SizeBytes()) +func (ke *KernelIPTGetEntries) marshalAll(cc marshal.CopyContext) []byte { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) ke.MarshalBytes(buf) return buf } diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index 9bb9efb10..a137940b6 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -17,9 +17,9 @@ package linux import ( "io" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // This file contains structures required to support IPv6 netfilter and @@ -128,9 +128,9 @@ func (ke *KernelIP6TGetEntries) UnmarshalUnsafe(src []byte) { } // CopyIn implements marshal.Marshallable.CopyIn. -func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - buf := task.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. - length, err := task.CopyInBytes(addr, buf) // escapes: okay. +func (ke *KernelIP6TGetEntries) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) // escapes: okay. + length, err := cc.CopyInBytes(addr, buf) // escapes: okay. // Unmarshal unconditionally. If we had a short copy-in, this results // in a partially unmarshalled struct. ke.UnmarshalBytes(buf) // escapes: fallback. @@ -138,21 +138,21 @@ func (ke *KernelIP6TGetEntries) CopyIn(task marshal.Task, addr usermem.Addr) (in } // CopyOut implements marshal.Marshallable.CopyOut. -func (ke *KernelIP6TGetEntries) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { +func (ke *KernelIP6TGetEntries) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { // Type KernelIP6TGetEntries doesn't have a packed layout in memory, // fall back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)) } // CopyOutN implements marshal.Marshallable.CopyOutN. -func (ke *KernelIP6TGetEntries) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { +func (ke *KernelIP6TGetEntries) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { // Type KernelIP6TGetEntries doesn't have a packed layout in memory, fall // back to MarshalBytes. - return task.CopyOutBytes(addr, ke.marshalAll(task)[:limit]) + return cc.CopyOutBytes(addr, ke.marshalAll(cc)[:limit]) } -func (ke *KernelIP6TGetEntries) marshalAll(task marshal.Task) []byte { - buf := task.CopyScratchBuffer(ke.SizeBytes()) +func (ke *KernelIP6TGetEntries) marshalAll(cc marshal.CopyContext) []byte { + buf := cc.CopyScratchBuffer(ke.SizeBytes()) ke.MarshalBytes(buf) return buf } @@ -290,6 +290,19 @@ type IP6TIP struct { const SizeOfIP6TIP = 136 +// Flags in IP6TIP.Flags. Corresponding constants are in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +const ( + // Whether to check the Protocol field. + IP6T_F_PROTO = 0x01 + // Whether to match the TOS field. + IP6T_F_TOS = 0x02 + // Indicates that the jump target is an aboslute GOTO, not an offset. + IP6T_F_GOTO = 0x04 + // Enables all flags. + IP6T_F_MASK = 0x07 +) + // Flags in IP6TIP.InverseFlags. Corresponding constants are in // include/uapi/linux/netfilter_ipv6/ip6_tables.h. const ( diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index 0ba086c76..b41f94a69 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -40,6 +40,8 @@ const ( ) // SockAddrNetlink is struct sockaddr_nl, from uapi/linux/netlink.h. +// +// +marshal type SockAddrNetlink struct { Family uint16 _ uint16 diff --git a/pkg/abi/linux/poll.go b/pkg/abi/linux/poll.go index c04d26e4c..3443a5768 100644 --- a/pkg/abi/linux/poll.go +++ b/pkg/abi/linux/poll.go @@ -15,6 +15,8 @@ package linux // PollFD is struct pollfd, used by poll(2)/ppoll(2), from uapi/asm-generic/poll.h. +// +// +marshal slice:PollFDSlice type PollFD struct { FD int32 Events int16 diff --git a/pkg/abi/linux/rusage.go b/pkg/abi/linux/rusage.go index d8302dc85..e29d0ac7e 100644 --- a/pkg/abi/linux/rusage.go +++ b/pkg/abi/linux/rusage.go @@ -26,6 +26,8 @@ const ( ) // Rusage represents the Linux struct rusage. +// +// +marshal type Rusage struct { UTime Timeval STime Timeval diff --git a/pkg/abi/linux/seccomp.go b/pkg/abi/linux/seccomp.go index d0607e256..b07cafe12 100644 --- a/pkg/abi/linux/seccomp.go +++ b/pkg/abi/linux/seccomp.go @@ -34,11 +34,11 @@ type BPFAction uint32 const ( SECCOMP_RET_KILL_PROCESS BPFAction = 0x80000000 - SECCOMP_RET_KILL_THREAD = 0x00000000 - SECCOMP_RET_TRAP = 0x00030000 - SECCOMP_RET_ERRNO = 0x00050000 - SECCOMP_RET_TRACE = 0x7ff00000 - SECCOMP_RET_ALLOW = 0x7fff0000 + SECCOMP_RET_KILL_THREAD BPFAction = 0x00000000 + SECCOMP_RET_TRAP BPFAction = 0x00030000 + SECCOMP_RET_ERRNO BPFAction = 0x00050000 + SECCOMP_RET_TRACE BPFAction = 0x7ff00000 + SECCOMP_RET_ALLOW BPFAction = 0x7fff0000 ) func (a BPFAction) String() string { @@ -64,6 +64,19 @@ func (a BPFAction) Data() uint16 { return uint16(a & SECCOMP_RET_DATA) } +// WithReturnCode sets the lower 16 bits of the SECCOMP_RET_ERRNO or +// SECCOMP_RET_TRACE actions to the provided return code, overwriting the previous +// action, and returns a new BPFAction. If not SECCOMP_RET_ERRNO or +// SECCOMP_RET_TRACE then this panics. +func (a BPFAction) WithReturnCode(code uint16) BPFAction { + // mask out the previous return value + baseAction := a & SECCOMP_RET_ACTION_FULL + if baseAction == SECCOMP_RET_ERRNO || baseAction == SECCOMP_RET_TRACE { + return BPFAction(uint32(baseAction) | uint32(code)) + } + panic("WithReturnCode only valid for SECCOMP_RET_ERRNO and SECCOMP_RET_TRACE") +} + // SockFprog is sock_fprog taken from <linux/filter.h>. type SockFprog struct { Len uint16 diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index de422c519..487a626cc 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -35,6 +35,8 @@ const ( const SEM_UNDO = 0x1000 // SemidDS is equivalent to struct semid64_ds. +// +// +marshal type SemidDS struct { SemPerm IPCPerm SemOTime TimeT @@ -45,6 +47,8 @@ type SemidDS struct { } // Sembuf is equivalent to struct sembuf. +// +// +marshal slice:SembufSlice type Sembuf struct { SemNum uint16 SemOp int16 diff --git a/pkg/abi/linux/shm.go b/pkg/abi/linux/shm.go index e45aadb10..274b1e847 100644 --- a/pkg/abi/linux/shm.go +++ b/pkg/abi/linux/shm.go @@ -51,6 +51,8 @@ const ( // ShmidDS is equivalent to struct shmid64_ds. Source: // include/uapi/asm-generic/shmbuf.h +// +// +marshal type ShmidDS struct { ShmPerm IPCPerm ShmSegsz uint64 @@ -66,6 +68,8 @@ type ShmidDS struct { } // ShmParams is equivalent to struct shminfo. Source: include/uapi/linux/shm.h +// +// +marshal type ShmParams struct { ShmMax uint64 ShmMin uint64 @@ -75,6 +79,8 @@ type ShmParams struct { } // ShmInfo is equivalent to struct shm_info. Source: include/uapi/linux/shm.h +// +// +marshal type ShmInfo struct { UsedIDs int32 // Number of currently existing segments. _ [4]byte diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index 1c330e763..6ca57ffbb 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -214,6 +214,8 @@ const ( ) // Sigevent represents struct sigevent. +// +// +marshal type Sigevent struct { Value uint64 // union sigval {int, void*} Signo int32 diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index e37c8727d..d156d41e4 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -14,7 +14,10 @@ package linux -import "gvisor.dev/gvisor/pkg/binary" +import ( + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" +) // Address families, from linux/socket.h. const ( @@ -265,6 +268,8 @@ type InetMulticastRequestWithNIC struct { type Inet6Addr [16]byte // SockAddrInet6 is struct sockaddr_in6, from uapi/linux/in6.h. +// +// +marshal type SockAddrInet6 struct { Family uint16 Port uint16 @@ -274,6 +279,8 @@ type SockAddrInet6 struct { } // SockAddrLink is a struct sockaddr_ll, from uapi/linux/if_packet.h. +// +// +marshal type SockAddrLink struct { Family uint16 Protocol uint16 @@ -290,6 +297,8 @@ type SockAddrLink struct { const UnixPathMax = 108 // SockAddrUnix is struct sockaddr_un, from uapi/linux/un.h. +// +// +marshal type SockAddrUnix struct { Family uint16 Path [UnixPathMax]int8 @@ -299,6 +308,8 @@ type SockAddrUnix struct { // equivalent to struct sockaddr. SockAddr ensures that a well-defined set of // types can be used as socket addresses. type SockAddr interface { + marshal.Marshallable + // implementsSockAddr exists purely to allow a type to indicate that they // implement this interface. This method is a no-op and shouldn't be called. implementsSockAddr() diff --git a/pkg/abi/linux/time.go b/pkg/abi/linux/time.go index e6860ed49..206f5af7e 100644 --- a/pkg/abi/linux/time.go +++ b/pkg/abi/linux/time.go @@ -93,6 +93,8 @@ const ( const maxSecInDuration = math.MaxInt64 / int64(time.Second) // TimeT represents time_t in <time.h>. It represents time in seconds. +// +// +marshal type TimeT int64 // NsecToTimeT translates nanoseconds to TimeT (seconds). @@ -102,7 +104,7 @@ func NsecToTimeT(nsec int64) TimeT { // Timespec represents struct timespec in <time.h>. // -// +marshal +// +marshal slice:TimespecSlice type Timespec struct { Sec int64 Nsec int64 @@ -158,7 +160,7 @@ const SizeOfTimeval = 16 // Timeval represents struct timeval in <time.h>. // -// +marshal +// +marshal slice:TimevalSlice type Timeval struct { Sec int64 Usec int64 @@ -196,6 +198,8 @@ func DurationToTimeval(dur time.Duration) Timeval { } // Itimerspec represents struct itimerspec in <time.h>. +// +// +marshal type Itimerspec struct { Interval Timespec Value Timespec @@ -206,12 +210,16 @@ type Itimerspec struct { // struct timeval it_interval; /* next value */ // struct timeval it_value; /* current value */ // }; +// +// +marshal type ItimerVal struct { Interval Timeval Value Timeval } // ClockT represents type clock_t. +// +// +marshal type ClockT int64 // ClockTFromDuration converts time.Duration to clock_t. @@ -220,6 +228,8 @@ func ClockTFromDuration(d time.Duration) ClockT { } // Tms represents struct tms, used by times(2). +// +// +marshal type Tms struct { UTime ClockT STime ClockT @@ -229,6 +239,8 @@ type Tms struct { // TimerID represents type timer_t, which identifies a POSIX per-process // interval timer. +// +// +marshal type TimerID int32 // StatxTimestamp represents struct statx_timestamp. diff --git a/pkg/abi/linux/tty.go b/pkg/abi/linux/tty.go index 8ac02aee8..47e65d9fb 100644 --- a/pkg/abi/linux/tty.go +++ b/pkg/abi/linux/tty.go @@ -23,6 +23,8 @@ const ( ) // Winsize is struct winsize, defined in uapi/asm-generic/termios.h. +// +// +marshal type Winsize struct { Row uint16 Col uint16 @@ -31,6 +33,8 @@ type Winsize struct { } // Termios is struct termios, defined in uapi/asm-generic/termbits.h. +// +// +marshal type Termios struct { InputFlags uint32 OutputFlags uint32 @@ -321,9 +325,9 @@ var MasterTermios = KernelTermios{ OutputSpeed: 38400, } -// DefaultSlaveTermios is the default terminal configuration of the slave end -// of a Unix98 pseudoterminal. -var DefaultSlaveTermios = KernelTermios{ +// DefaultReplicaTermios is the default terminal configuration of the replica +// end of a Unix98 pseudoterminal. +var DefaultReplicaTermios = KernelTermios{ InputFlags: ICRNL | IXON, OutputFlags: OPOST | ONLCR, ControlFlags: B38400 | CS8 | CREAD, @@ -337,6 +341,7 @@ var DefaultSlaveTermios = KernelTermios{ // include/uapi/asm-generic/termios.h. // // +stateify savable +// +marshal type WindowSize struct { Rows uint16 Cols uint16 diff --git a/pkg/abi/linux/utsname.go b/pkg/abi/linux/utsname.go index 60f220a67..cb7c95437 100644 --- a/pkg/abi/linux/utsname.go +++ b/pkg/abi/linux/utsname.go @@ -26,6 +26,8 @@ const ( ) // UtsName represents struct utsname, the struct returned by uname(2). +// +// +marshal type UtsName struct { Sysname [UTSLen + 1]byte Nodename [UTSLen + 1]byte diff --git a/pkg/abi/linux/xattr.go b/pkg/abi/linux/xattr.go index 99180b208..8ef837f27 100644 --- a/pkg/abi/linux/xattr.go +++ b/pkg/abi/linux/xattr.go @@ -23,6 +23,9 @@ const ( XATTR_CREATE = 1 XATTR_REPLACE = 2 + XATTR_TRUSTED_PREFIX = "trusted." + XATTR_TRUSTED_PREFIX_LEN = len(XATTR_TRUSTED_PREFIX) + XATTR_USER_PREFIX = "user." XATTR_USER_PREFIX_LEN = len(XATTR_USER_PREFIX) ) diff --git a/pkg/amutex/BUILD b/pkg/amutex/BUILD index ffc918846..bd3a5cce9 100644 --- a/pkg/amutex/BUILD +++ b/pkg/amutex/BUILD @@ -6,7 +6,10 @@ go_library( name = "amutex", srcs = ["amutex.go"], visibility = ["//:sandbox"], - deps = ["//pkg/syserror"], + deps = [ + "//pkg/context", + "//pkg/syserror", + ], ) go_test( diff --git a/pkg/amutex/amutex.go b/pkg/amutex/amutex.go index a078a31db..d7acc1d9f 100644 --- a/pkg/amutex/amutex.go +++ b/pkg/amutex/amutex.go @@ -19,41 +19,17 @@ package amutex import ( "sync/atomic" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserror" ) // Sleeper must be implemented by users of the abortable mutex to allow for // cancellation of waits. -type Sleeper interface { - // SleepStart is called by the AbortableMutex.Lock() function when the - // mutex is contended and the goroutine is about to sleep. - // - // A channel can be returned that causes the sleep to be canceled if - // it's readable. If no cancellation is desired, nil can be returned. - SleepStart() <-chan struct{} - - // SleepFinish is called by AbortableMutex.Lock() once a contended mutex - // is acquired or the wait is aborted. - SleepFinish(success bool) - - // Interrupted returns true if the wait is aborted. - Interrupted() bool -} +type Sleeper = context.ChannelSleeper // NoopSleeper is a stateless no-op implementation of Sleeper for anonymous // embedding in other types that do not support cancelation. -type NoopSleeper struct{} - -// SleepStart implements Sleeper.SleepStart. -func (NoopSleeper) SleepStart() <-chan struct{} { - return nil -} - -// SleepFinish implements Sleeper.SleepFinish. -func (NoopSleeper) SleepFinish(success bool) {} - -// Interrupted implements Sleeper.Interrupted. -func (NoopSleeper) Interrupted() bool { return false } +type NoopSleeper = context.Context // Block blocks until either receiving from ch succeeds (in which case it // returns nil) or sleeper is interrupted (in which case it returns diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go index c8ee0c3b1..069d0395d 100644 --- a/pkg/bpf/decoder.go +++ b/pkg/bpf/decoder.go @@ -21,10 +21,15 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// DecodeProgram translates an array of BPF instructions into text format. -func DecodeProgram(program []linux.BPFInstruction) (string, error) { +// DecodeProgram translates a compiled BPF program into text format. +func DecodeProgram(p Program) (string, error) { + return DecodeInstructions(p.instructions) +} + +// DecodeInstructions translates an array of BPF instructions into text format. +func DecodeInstructions(instns []linux.BPFInstruction) (string, error) { var ret bytes.Buffer - for line, s := range program { + for line, s := range instns { ret.WriteString(fmt.Sprintf("%v: ", line)) if err := decode(s, line, &ret); err != nil { return "", err @@ -34,7 +39,7 @@ func DecodeProgram(program []linux.BPFInstruction) (string, error) { return ret.String(), nil } -// Decode translates BPF instruction into text format. +// Decode translates a single BPF instruction into text format. func Decode(inst linux.BPFInstruction) (string, error) { var ret bytes.Buffer err := decode(inst, -1, &ret) diff --git a/pkg/bpf/decoder_test.go b/pkg/bpf/decoder_test.go index 6a023f0c0..bb971ce21 100644 --- a/pkg/bpf/decoder_test.go +++ b/pkg/bpf/decoder_test.go @@ -93,7 +93,7 @@ func TestDecode(t *testing.T) { } } -func TestDecodeProgram(t *testing.T) { +func TestDecodeInstructions(t *testing.T) { for _, test := range []struct { name string program []linux.BPFInstruction @@ -126,7 +126,7 @@ func TestDecodeProgram(t *testing.T) { program: []linux.BPFInstruction{Stmt(Ld+Abs+W, 10), Stmt(Ld+Len+Mem, 0)}, fail: true}, } { - got, err := DecodeProgram(test.program) + got, err := DecodeInstructions(test.program) if test.fail { if err == nil { t.Errorf("%s: Decode(...) failed, expected: 'error', got: %q", test.name, got) diff --git a/pkg/bpf/program_builder.go b/pkg/bpf/program_builder.go index 7992044d0..caaf99c83 100644 --- a/pkg/bpf/program_builder.go +++ b/pkg/bpf/program_builder.go @@ -32,13 +32,21 @@ type ProgramBuilder struct { // Maps label names to label objects. labels map[string]*label + // unusableLabels are labels that are added before being referenced in a + // jump. Any labels added this way cannot be referenced later in order to + // avoid backwards references. + unusableLabels map[string]bool + // Array of BPF instructions that makes up the program. instructions []linux.BPFInstruction } // NewProgramBuilder creates a new ProgramBuilder instance. func NewProgramBuilder() *ProgramBuilder { - return &ProgramBuilder{labels: map[string]*label{}} + return &ProgramBuilder{ + labels: map[string]*label{}, + unusableLabels: map[string]bool{}, + } } // label contains information to resolve a label to an offset. @@ -108,9 +116,12 @@ func (b *ProgramBuilder) AddJumpLabels(code uint16, k uint32, jtLabel, jfLabel s func (b *ProgramBuilder) AddLabel(name string) error { l, ok := b.labels[name] if !ok { - // This is done to catch jump backwards cases, but it's not strictly wrong - // to have unused labels. - return fmt.Errorf("Adding a label that hasn't been used is not allowed: %v", name) + if _, ok = b.unusableLabels[name]; ok { + return fmt.Errorf("label %q already set", name) + } + // Mark the label as unusable. This is done to catch backwards jumps. + b.unusableLabels[name] = true + return nil } if l.target != -1 { return fmt.Errorf("label %q target already set: %v", name, l.target) @@ -141,6 +152,10 @@ func (b *ProgramBuilder) addLabelSource(labelName string, t jmpType) { func (b *ProgramBuilder) resolveLabels() error { for key, v := range b.labels { + if _, ok := b.unusableLabels[key]; ok { + return fmt.Errorf("backwards reference detected for label: %q", key) + } + if v.target == -1 { return fmt.Errorf("label target not set: %v", key) } diff --git a/pkg/bpf/program_builder_test.go b/pkg/bpf/program_builder_test.go index 92ca5f4c3..37f684f25 100644 --- a/pkg/bpf/program_builder_test.go +++ b/pkg/bpf/program_builder_test.go @@ -26,16 +26,16 @@ func validate(p *ProgramBuilder, expected []linux.BPFInstruction) error { if err != nil { return fmt.Errorf("Instructions() failed: %v", err) } - got, err := DecodeProgram(instructions) + got, err := DecodeInstructions(instructions) if err != nil { - return fmt.Errorf("DecodeProgram('instructions') failed: %v", err) + return fmt.Errorf("DecodeInstructions('instructions') failed: %v", err) } - expectedDecoded, err := DecodeProgram(expected) + expectedDecoded, err := DecodeInstructions(expected) if err != nil { - return fmt.Errorf("DecodeProgram('expected') failed: %v", err) + return fmt.Errorf("DecodeInstructions('expected') failed: %v", err) } if got != expectedDecoded { - return fmt.Errorf("DecodeProgram() failed, expected: %q, got: %q", expectedDecoded, got) + return fmt.Errorf("DecodeInstructions() failed, expected: %q, got: %q", expectedDecoded, got) } return nil } @@ -124,10 +124,38 @@ func TestProgramBuilderLabelWithNoInstruction(t *testing.T) { } } +// TestProgramBuilderUnusedLabel tests that adding an unused label doesn't +// cause program generation to fail. func TestProgramBuilderUnusedLabel(t *testing.T) { p := NewProgramBuilder() - if err := p.AddLabel("unused"); err == nil { - t.Errorf("AddLabel(unused) should have failed") + p.AddStmt(Ld+Abs+W, 10) + p.AddJump(Jmp+Ja, 10, 0, 0) + + expected := []linux.BPFInstruction{ + Stmt(Ld+Abs+W, 10), + Jump(Jmp+Ja, 10, 0, 0), + } + + if err := p.AddLabel("unused"); err != nil { + t.Errorf("AddLabel(unused) should have succeeded") + } + + if err := validate(p, expected); err != nil { + t.Errorf("Validate() failed: %v", err) + } +} + +// TestProgramBuilderBackwardsReference tests that including a backwards +// reference to a label in a program causes a failure. +func TestProgramBuilderBackwardsReference(t *testing.T) { + p := NewProgramBuilder() + if err := p.AddLabel("bw_label"); err != nil { + t.Errorf("failed to add label") + } + p.AddStmt(Ld+Abs+W, 10) + p.AddJumpTrueLabel(Jmp+Jeq+K, 10, "bw_label", 0) + if _, err := p.Instructions(); err == nil { + t.Errorf("Instructions() should have failed") } } diff --git a/pkg/buffer/BUILD b/pkg/buffer/BUILD index dcd086298..b03d46d18 100644 --- a/pkg/buffer/BUILD +++ b/pkg/buffer/BUILD @@ -26,8 +26,10 @@ go_library( ], visibility = ["//visibility:public"], deps = [ + "//pkg/context", "//pkg/log", "//pkg/safemem", + "//pkg/usermem", ], ) diff --git a/pkg/context/BUILD b/pkg/context/BUILD index 239f31149..f33e23bf7 100644 --- a/pkg/context/BUILD +++ b/pkg/context/BUILD @@ -7,7 +7,6 @@ go_library( srcs = ["context.go"], visibility = ["//:sandbox"], deps = [ - "//pkg/amutex", "//pkg/log", ], ) diff --git a/pkg/context/context.go b/pkg/context/context.go index 5319b6d8d..2613bc752 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -26,7 +26,6 @@ import ( "context" "time" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -68,9 +67,10 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { // In both cases, values extracted from the Context should be used instead. type Context interface { log.Logger - amutex.Sleeper context.Context + ChannelSleeper + // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate // is true and the Context represents a Task, the Task's AddressSpace is @@ -85,29 +85,60 @@ type Context interface { UninterruptibleSleepFinish(activate bool) } -// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep -// methods for anonymous embedding in other types that do not implement sleeps. -type NoopSleeper struct { - amutex.NoopSleeper +// A ChannelSleeper represents a goroutine that may sleep interruptibly, where +// interruption is indicated by a channel becoming readable. +type ChannelSleeper interface { + // SleepStart is called before going to sleep interruptibly. If SleepStart + // returns a non-nil channel and that channel becomes ready for receiving + // while the goroutine is sleeping, the goroutine should be woken, and + // SleepFinish(false) should be called. Otherwise, SleepFinish(true) should + // be called after the goroutine stops sleeping. + SleepStart() <-chan struct{} + + // SleepFinish is called after an interruptibly-sleeping goroutine stops + // sleeping, as documented by SleepStart. + SleepFinish(success bool) + + // Interrupted returns true if the channel returned by SleepStart is + // ready for receiving. + Interrupted() bool +} + +// NoopSleeper is a noop implementation of ChannelSleeper and +// Context.UninterruptibleSleep* methods for anonymous embedding in other types +// that do not implement special behavior around sleeps. +type NoopSleeper struct{} + +// SleepStart implements ChannelSleeper.SleepStart. +func (NoopSleeper) SleepStart() <-chan struct{} { + return nil +} + +// SleepFinish implements ChannelSleeper.SleepFinish. +func (NoopSleeper) SleepFinish(success bool) {} + +// Interrupted implements ChannelSleeper.Interrupted. +func (NoopSleeper) Interrupted() bool { + return false } -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} +// UninterruptibleSleepStart implements Context.UninterruptibleSleepStart. +func (NoopSleeper) UninterruptibleSleepStart(deactivate bool) {} -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} +// UninterruptibleSleepFinish implements Context.UninterruptibleSleepFinish. +func (NoopSleeper) UninterruptibleSleepFinish(activate bool) {} -// Deadline returns zero values, meaning no deadline. +// Deadline implements context.Context.Deadline. func (NoopSleeper) Deadline() (time.Time, bool) { return time.Time{}, false } -// Done returns nil. +// Done implements context.Context.Done. func (NoopSleeper) Done() <-chan struct{} { return nil } -// Err returns nil. +// Err returns context.Context.Err. func (NoopSleeper) Err() error { return nil } diff --git a/pkg/coverage/BUILD b/pkg/coverage/BUILD new file mode 100644 index 000000000..a198e8028 --- /dev/null +++ b/pkg/coverage/BUILD @@ -0,0 +1,14 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "coverage", + srcs = ["coverage.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/sync", + "//pkg/usermem", + "@io_bazel_rules_go//go/tools/coverdata", + ], +) diff --git a/pkg/coverage/coverage.go b/pkg/coverage/coverage.go new file mode 100644 index 000000000..6831adcce --- /dev/null +++ b/pkg/coverage/coverage.go @@ -0,0 +1,175 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package coverage provides an interface through which Go coverage data can +// be collected, converted to kcov format, and exposed to userspace. +// +// Coverage can be enabled by calling bazel {build,test} with +// --collect_coverage_data and --instrumentation_filter with the desired +// coverage surface. This causes bazel to use the Go cover tool manually to +// generate instrumented files. It injects a hook that registers all coverage +// data with the coverdata package. +package coverage + +import ( + "fmt" + "io" + "sort" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/usermem" + + "github.com/bazelbuild/rules_go/go/tools/coverdata" +) + +// KcovAvailable returns whether the kcov coverage interface is available. It is +// available as long as coverage is enabled for some files. +func KcovAvailable() bool { + return len(coverdata.Cover.Blocks) > 0 +} + +// coverageMu must be held while accessing coverdata.Cover. This prevents +// concurrent reads/writes from multiple threads collecting coverage data. +var coverageMu sync.RWMutex + +// once ensures that globalData is only initialized once. +var once sync.Once + +var globalData struct { + // files is the set of covered files sorted by filename. It is calculated at + // startup. + files []string + + // syntheticPCs are a set of PCs calculated at startup, where the PC + // at syntheticPCs[i][j] corresponds to file i, block j. + syntheticPCs [][]uint64 +} + +// ClearCoverageData clears existing coverage data. +func ClearCoverageData() { + coverageMu.Lock() + defer coverageMu.Unlock() + for _, counters := range coverdata.Cover.Counters { + for index := 0; index < len(counters); index++ { + atomic.StoreUint32(&counters[index], 0) + } + } +} + +var coveragePool = sync.Pool{ + New: func() interface{} { + return make([]byte, 0) + }, +} + +// ConsumeCoverageData builds and writes the collection of covered PCs. It +// returns the number of bytes written. +// +// In Linux, a kernel configuration is set that compiles the kernel with a +// custom function that is called at the beginning of every basic block, which +// updates the memory-mapped coverage information. The Go coverage tool does not +// allow us to inject arbitrary instructions into basic blocks, but it does +// provide data that we can convert to a kcov-like format and transfer them to +// userspace through a memory mapping. +// +// Note that this is not a strict implementation of kcov, which is especially +// tricky to do because we do not have the same coverage tools available in Go +// that that are available for the actual Linux kernel. In Linux, a kernel +// configuration is set that compiles the kernel with a custom function that is +// called at the beginning of every basic block to write program counters to the +// kcov memory mapping. In Go, however, coverage tools only give us a count of +// basic blocks as they are executed. Every time we return to userspace, we +// collect the coverage information and write out PCs for each block that was +// executed, providing userspace with the illusion that the kcov data is always +// up to date. For convenience, we also generate a unique synthetic PC for each +// block instead of using actual PCs. Finally, we do not provide thread-specific +// coverage data (each kcov instance only contains PCs executed by the thread +// owning it); instead, we will supply data for any file specified by -- +// instrumentation_filter. +// +// Note that we "consume", i.e. clear, coverdata when this function is run, to +// ensure that each event is only reported once. +// +// TODO(b/160639712): evaluate whether it is ok to reset the global coverage +// data every time this function is run. We could technically have each thread +// store a local snapshot against which we compare the most recent coverdata so +// that separate threads do not affect each other's view of the data. +func ConsumeCoverageData(w io.Writer) int { + once.Do(initCoverageData) + + coverageMu.Lock() + defer coverageMu.Unlock() + + total := 0 + var pcBuffer [8]byte + for fileIndex, file := range globalData.files { + counters := coverdata.Cover.Counters[file] + for index := 0; index < len(counters); index++ { + val := atomic.SwapUint32(&counters[index], 0) + if val != 0 { + // Calculate the synthetic PC. + pc := globalData.syntheticPCs[fileIndex][index] + + usermem.ByteOrder.PutUint64(pcBuffer[:], pc) + n, err := w.Write(pcBuffer[:]) + if err != nil { + if err == io.EOF { + // Simply stop writing if we encounter EOF; it's ok if we attempted to + // write more than we can hold. + return total + n + } + panic(fmt.Sprintf("Internal error writing PCs to kcov area: %v", err)) + } + total += n + } + } + } + + if total == 0 { + // An empty profile indicates that coverage is not enabled, in which case + // there shouldn't be any task work registered. + panic("kcov task work is registered, but no coverage data was found") + } + return total +} + +// initCoverageData initializes globalData. It should only be called once, +// before any kcov data is written. +func initCoverageData() { + // First, order all files. Then calculate synthetic PCs for every block + // (using the well-defined ordering for files as well). + for file := range coverdata.Cover.Blocks { + globalData.files = append(globalData.files, file) + } + sort.Strings(globalData.files) + + // nextSyntheticPC is the first PC that we generate for a block. + // + // This uses a standard-looking kernel range for simplicity. + // + // FIXME(b/160639712): This is only necessary because syzkaller requires + // addresses in the kernel range. If we can remove this constraint, then we + // should be able to use the actual addresses. + var nextSyntheticPC uint64 = 0xffffffff80000000 + for _, file := range globalData.files { + blocks := coverdata.Cover.Blocks[file] + thisFile := make([]uint64, 0, len(blocks)) + for range blocks { + thisFile = append(thisFile, nextSyntheticPC) + nextSyntheticPC++ // Advance. + } + globalData.syntheticPCs = append(globalData.syntheticPCs, thisFile) + } +} diff --git a/pkg/cpuid/cpuid_parse_x86_test.go b/pkg/cpuid/cpuid_parse_x86_test.go index c9bd40e1b..e4ae0d689 100644 --- a/pkg/cpuid/cpuid_parse_x86_test.go +++ b/pkg/cpuid/cpuid_parse_x86_test.go @@ -32,27 +32,27 @@ func kernelVersion() (int, int, error) { return 0, 0, err } - var r string + var sb strings.Builder for _, b := range u.Release { if b == 0 { break } - r += string(b) + sb.WriteByte(byte(b)) } - s := strings.Split(r, ".") + s := strings.Split(sb.String(), ".") if len(s) < 2 { - return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", r) + return 0, 0, fmt.Errorf("kernel release missing major and minor component: %s", sb.String()) } major, err := strconv.Atoi(s[0]) if err != nil { - return 0, 0, fmt.Errorf("error parsing major version %q in %q: %v", s[0], r, err) + return 0, 0, fmt.Errorf("error parsing major version %q in %q: %w", s[0], sb.String(), err) } minor, err := strconv.Atoi(s[1]) if err != nil { - return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %v", s[1], r, err) + return 0, 0, fmt.Errorf("error parsing minor version %q in %q: %w", s[1], sb.String(), err) } return major, minor, nil diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go index 83bcfe220..cc6b0cdf1 100644 --- a/pkg/fd/fd.go +++ b/pkg/fd/fd.go @@ -49,7 +49,7 @@ func fixCount(n int, err error) (int, error) { // Read implements io.Reader. func (r *ReadWriter) Read(b []byte) (int, error) { - c, err := fixCount(syscall.Read(int(atomic.LoadInt64(&r.fd)), b)) + c, err := fixCount(syscall.Read(r.FD(), b)) if c == 0 && len(b) > 0 && err == nil { return 0, io.EOF } @@ -62,7 +62,7 @@ func (r *ReadWriter) Read(b []byte) (int, error) { func (r *ReadWriter) ReadAt(b []byte, off int64) (c int, err error) { for len(b) > 0 { var m int - m, err = fixCount(syscall.Pread(int(atomic.LoadInt64(&r.fd)), b, off)) + m, err = fixCount(syscall.Pread(r.FD(), b, off)) if m == 0 && err == nil { return c, io.EOF } @@ -82,7 +82,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) { var n, remaining int for remaining = len(b); remaining > 0; { woff := len(b) - remaining - n, err = syscall.Write(int(atomic.LoadInt64(&r.fd)), b[woff:]) + n, err = syscall.Write(r.FD(), b[woff:]) if n > 0 { // syscall.Write wrote some bytes. This is the common case. @@ -110,7 +110,7 @@ func (r *ReadWriter) Write(b []byte) (int, error) { func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) { for len(b) > 0 { var m int - m, err = fixCount(syscall.Pwrite(int(atomic.LoadInt64(&r.fd)), b, off)) + m, err = fixCount(syscall.Pwrite(r.FD(), b, off)) if err != nil { break } @@ -121,6 +121,16 @@ func (r *ReadWriter) WriteAt(b []byte, off int64) (c int, err error) { return } +// FD returns the owned file descriptor. Ownership remains unchanged. +func (r *ReadWriter) FD() int { + return int(atomic.LoadInt64(&r.fd)) +} + +// String implements Stringer.String(). +func (r *ReadWriter) String() string { + return fmt.Sprintf("FD: %d", r.FD()) +} + // FD owns a host file descriptor. // // It is similar to os.File, with a few important distinctions: @@ -167,6 +177,23 @@ func NewFromFile(file *os.File) (*FD, error) { return New(fd), nil } +// NewFromFiles creates new FDs for each file in the slice. +func NewFromFiles(files []*os.File) ([]*FD, error) { + rv := make([]*FD, 0, len(files)) + for _, f := range files { + new, err := NewFromFile(f) + if err != nil { + // Cleanup on error. + for _, fd := range rv { + fd.Close() + } + return nil, err + } + rv = append(rv, new) + } + return rv, nil +} + // Open is equivalent to open(2). func Open(path string, openmode int, perm uint32) (*FD, error) { f, err := syscall.Open(path, openmode|syscall.O_LARGEFILE, perm) @@ -204,11 +231,6 @@ func (f *FD) Release() int { return int(atomic.SwapInt64(&f.fd, -1)) } -// FD returns the file descriptor owned by FD. FD retains ownership. -func (f *FD) FD() int { - return int(atomic.LoadInt64(&f.fd)) -} - // File converts the FD to an os.File. // // FD does not transfer ownership of the file descriptor (it will be @@ -219,7 +241,7 @@ func (f *FD) FD() int { // This operation is somewhat expensive, so care should be taken to minimize // its use. func (f *FD) File() (*os.File, error) { - fd, err := syscall.Dup(int(atomic.LoadInt64(&f.fd))) + fd, err := syscall.Dup(f.FD()) if err != nil { return nil, err } diff --git a/pkg/fdnotifier/poll_unsafe.go b/pkg/fdnotifier/poll_unsafe.go index 4225b04dd..ec2f997a2 100644 --- a/pkg/fdnotifier/poll_unsafe.go +++ b/pkg/fdnotifier/poll_unsafe.go @@ -65,8 +65,7 @@ func NonBlockingPoll(fd int32, mask waiter.EventMask) waiter.EventMask { // epollWait performs a blocking wait on epfd. // -// Preconditions: -// * len(events) > 0 +// Preconditions: len(events) > 0 func epollWait(epfd int, events []syscall.EpollEvent, msec int) (int, error) { if len(events) == 0 { panic("Empty events passed to EpollWait") diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index ec742c091..c4a3366ce 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -179,8 +179,10 @@ const ( // Connect blocks until the peer Endpoint has called Endpoint.RecvFirst(). // -// Preconditions: ep is a client Endpoint. ep.Connect(), ep.RecvFirst(), -// ep.SendRecv(), and ep.SendLast() have never been called. +// Preconditions: +// * ep is a client Endpoint. +// * ep.Connect(), ep.RecvFirst(), ep.SendRecv(), and ep.SendLast() have never +// been called. func (ep *Endpoint) Connect() error { err := ep.ctrlConnect() if err == nil { @@ -192,8 +194,9 @@ func (ep *Endpoint) Connect() error { // RecvFirst blocks until the peer Endpoint calls Endpoint.SendRecv(), then // returns the datagram length specified by that call. // -// Preconditions: ep is a server Endpoint. ep.SendRecv(), ep.RecvFirst(), and -// ep.SendLast() have never been called. +// Preconditions: +// * ep is a server Endpoint. +// * ep.SendRecv(), ep.RecvFirst(), and ep.SendLast() have never been called. func (ep *Endpoint) RecvFirst() (uint32, error) { if err := ep.ctrlWaitFirst(); err != nil { return 0, err @@ -211,10 +214,12 @@ func (ep *Endpoint) RecvFirst() (uint32, error) { // datagram length, then blocks until the peer Endpoint calls // Endpoint.SendRecv() or Endpoint.SendLast(). // -// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or -// ep.RecvFirst() has returned an error. ep.SendLast() has never been called. -// If ep is a client Endpoint, ep.Connect() has previously been called and -// returned nil. +// Preconditions: +// * dataLen <= ep.DataCap(). +// * No previous call to ep.SendRecv() or ep.RecvFirst() has returned an error. +// * ep.SendLast() has never been called. +// * If ep is a client Endpoint, ep.Connect() has previously been called and +// returned nil. func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { if dataLen > ep.dataCap { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) @@ -240,10 +245,12 @@ func (ep *Endpoint) SendRecv(dataLen uint32) (uint32, error) { // SendLast causes the peer Endpoint's call to Endpoint.SendRecv() or // Endpoint.RecvFirst() to return with the given datagram length. // -// Preconditions: dataLen <= ep.DataCap(). No previous call to ep.SendRecv() or -// ep.RecvFirst() has returned an error. ep.SendLast() has never been called. -// If ep is a client Endpoint, ep.Connect() has previously been called and -// returned nil. +// Preconditions: +// * dataLen <= ep.DataCap(). +// * No previous call to ep.SendRecv() or ep.RecvFirst() has returned an error. +// * ep.SendLast() has never been called. +// * If ep is a client Endpoint, ep.Connect() has previously been called and +// returned nil. func (ep *Endpoint) SendLast(dataLen uint32) error { if dataLen > ep.dataCap { panic(fmt.Sprintf("attempting to send packet with datagram length %d (maximum %d)", dataLen, ep.dataCap)) diff --git a/pkg/lisafs/README.md b/pkg/lisafs/README.md new file mode 100644 index 000000000..51d0d40e5 --- /dev/null +++ b/pkg/lisafs/README.md @@ -0,0 +1,363 @@ +# Replacing 9P + +## Background + +The Linux filesystem model consists of the following key aspects (modulo mounts, +which are outside the scope of this discussion): + +- A `struct inode` represents a "filesystem object", such as a directory or a + regular file. "Filesystem object" is most precisely defined by the practical + properties of an inode, such as an immutable type (regular file, directory, + symbolic link, etc.) and its independence from the path originally used to + obtain it. + +- A `struct dentry` represents a node in a filesystem tree. Semantically, each + dentry is immutably associated with an inode representing the filesystem + object at that position. (Linux implements optimizations involving reuse of + unreferenced dentries, which allows their associated inodes to change, but + this is outside the scope of this discussion.) + +- A `struct file` represents an open file description (hereafter FD) and is + needed to perform I/O. Each FD is immutably associated with the dentry + through which it was opened. + +The current gVisor virtual filesystem implementation (hereafter VFS1) closely +imitates the Linux design: + +- `struct inode` => `fs.Inode` + +- `struct dentry` => `fs.Dirent` + +- `struct file` => `fs.File` + +gVisor accesses most external filesystems through a variant of the 9P2000.L +protocol, including extensions for performance (`walkgetattr`) and for features +not supported by vanilla 9P2000.L (`flushf`, `lconnect`). The 9P protocol family +is inode-based; 9P fids represent a file (equivalently "file system object"), +and the protocol is structured around alternatively obtaining fids to represent +files (with `walk` and, in gVisor, `walkgetattr`) and performing operations on +those fids. + +In the sections below, a **shared** filesystem is a filesystem that is *mutably* +accessible by multiple concurrent clients, such that a **non-shared** filesystem +is a filesystem that is either read-only or accessible by only a single client. + +## Problems + +### Serialization of Path Component RPCs + +Broadly speaking, VFS1 traverses each path component in a pathname, alternating +between verifying that each traversed dentry represents an inode that represents +a searchable directory and moving to the next dentry in the path. + +In the context of a remote filesystem, the structure of this traversal means +that - modulo caching - a path involving N components requires at least N-1 +*sequential* RPCs to obtain metadata for intermediate directories, incurring +significant latency. (In vanilla 9P2000.L, 2(N-1) RPCs are required: N-1 `walk` +and N-1 `getattr`. We added the `walkgetattr` RPC to reduce this overhead.) On +non-shared filesystems, this overhead is primarily significant during +application startup; caching mitigates much of this overhead at steady state. On +shared filesystems, where correct caching requires revalidation (requiring RPCs +for each revalidated directory anyway), this overhead is consistently ruinous. + +### Inefficient RPCs + +9P is not exceptionally economical with RPCs in general. In addition to the +issue described above: + +- Opening an existing file in 9P involves at least 2 RPCs: `walk` to produce + an unopened fid representing the file, and `lopen` to open the fid. + +- Creating a file also involves at least 2 RPCs: `walk` to produce an unopened + fid representing the parent directory, and `lcreate` to create the file and + convert the fid to an open fid representing the created file. In practice, + both the Linux and gVisor 9P clients expect to have an unopened fid for the + created file (necessitating an additional `walk`), as well as attributes for + the created file (necessitating an additional `getattr`), for a total of 4 + RPCs. (In a shared filesystem, where whether a file already exists can + change between RPCs, a correct implementation of `open(O_CREAT)` would have + to alternate between these two paths (plus `clunk`ing the temporary fid + between alternations, since the nature of the `fid` differs between the two + paths). Neither Linux nor gVisor implement the required alternation, so + `open(O_CREAT)` without `O_EXCL` can spuriously fail with `EEXIST` on both.) + +- Closing (`clunk`ing) a fid requires an RPC. VFS1 issues this RPC + asynchronously in an attempt to reduce critical path latency, but scheduling + overhead makes this not clearly advantageous in practice. + +- `read` and `readdir` can return partial reads without a way to indicate EOF, + necessitating an additional final read to detect EOF. + +- Operations that affect filesystem state do not consistently return updated + filesystem state. In gVisor, the client implementation attempts to handle + this by tracking what it thinks updated state "should" be; this is complex, + and especially brittle for timestamps (which are often not arbitrarily + settable). In Linux, the client implemtation invalidates cached metadata + whenever it performs such an operation, and reloads it when a dentry + corresponding to an inode with no valid cached metadata is revalidated; this + is simple, but necessitates an additional `getattr`. + +### Dentry/Inode Ambiguity + +As noted above, 9P's documentation tends to imply that unopened fids represent +an inode. In practice, most filesystem APIs present very limited interfaces for +working with inodes at best, such that the interpretation of unopened fids +varies: + +- Linux's 9P client associates unopened fids with (dentry, uid) pairs. When + caching is enabled, it also associates each inode with the first fid opened + writably that references that inode, in order to support page cache + writeback. + +- gVisor's 9P client associates unopened fids with inodes, and also caches + opened fids in inodes in a manner similar to Linux. + +- The runsc fsgofer associates unopened fids with both "dentries" (host + filesystem paths) and "inodes" (host file descriptors); which is used + depends on the operation invoked on the fid. + +For non-shared filesystems, this confusion has resulted in correctness issues +that are (in gVisor) currently handled by a number of coarse-grained locks that +serialize renames with all other filesystem operations. For shared filesystems, +this means inconsistent behavior in the presence of concurrent mutation. + +## Design + +Almost all Linux filesystem syscalls describe filesystem resources in one of two +ways: + +- Path-based: A filesystem position is described by a combination of a + starting position and a sequence of path components relative to that + position, where the starting position is one of: + + - The VFS root (defined by mount namespace and chroot), for absolute paths + + - The VFS position of an existing FD, for relative paths passed to `*at` + syscalls (e.g. `statat`) + + - The current working directory, for relative paths passed to non-`*at` + syscalls and `*at` syscalls with `AT_FDCWD` + +- File-description-based: A filesystem object is described by an existing FD, + passed to a `f*` syscall (e.g. `fstat`). + +Many of our issues with 9P arise from its (and VFS') interposition of a model +based on inodes between the filesystem syscall API and filesystem +implementations. We propose to replace 9P with a protocol that does not feature +inodes at all, and instead closely follows the filesystem syscall API by +featuring only path-based and FD-based operations, with minimal deviations as +necessary to ameliorate deficiencies in the syscall interface (see below). This +approach addresses the issues described above: + +- Even on shared filesystems, most application filesystem syscalls are + translated to a single RPC (possibly excepting special cases described + below), which is a logical lower bound. + +- The behavior of application syscalls on shared filesystems is + straightforwardly predictable: path-based syscalls are translated to + path-based RPCs, which will re-lookup the file at that path, and FD-based + syscalls are translated to FD-based RPCs, which use an existing open file + without performing another lookup. (This is at least true on gofers that + proxy the host local filesystem; other filesystems that lack support for + e.g. certain operations on FDs may have different behavior, but this + divergence is at least still predictable and inherent to the underlying + filesystem implementation.) + +Note that this approach is only feasible in gVisor's next-generation virtual +filesystem (VFS2), which does not assume the existence of inodes and allows the +remote filesystem client to translate whole path-based syscalls into RPCs. Thus +one of the unavoidable tradeoffs associated with such a protocol vs. 9P is the +inability to construct a Linux client that is performance-competitive with +gVisor. + +### File Permissions + +Many filesystem operations are side-effectual, such that file permissions must +be checked before such operations take effect. The simplest approach to file +permission checking is for the sentry to obtain permissions from the remote +filesystem, then apply permission checks in the sentry before performing the +application-requested operation. However, this requires an additional RPC per +application syscall (which can't be mitigated by caching on shared filesystems). +Alternatively, we may delegate file permission checking to gofers. In general, +file permission checks depend on the following properties of the accessor: + +- Filesystem UID/GID + +- Supplementary GIDs + +- Effective capabilities in the accessor's user namespace (i.e. the accessor's + effective capability set) + +- All UIDs and GIDs mapped in the accessor's user namespace (which determine + if the accessor's capabilities apply to accessed files) + +We may choose to delay implementation of file permission checking delegation, +although this is potentially costly since it doubles the number of required RPCs +for most operations on shared filesystems. We may also consider compromise +options, such as only delegating file permission checks for accessors in the +root user namespace. + +### Symbolic Links + +gVisor usually interprets symbolic link targets in its VFS rather than on the +filesystem containing the symbolic link; thus e.g. a symlink to +"/proc/self/maps" on a remote filesystem resolves to said file in the sentry's +procfs rather than the host's. This implies that: + +- Remote filesystem servers that proxy filesystems supporting symlinks must + check if each path component is a symlink during path traversal. + +- Absolute symlinks require that the sentry restart the operation at its + contextual VFS root (which is task-specific and may not be on a remote + filesystem at all), so if a remote filesystem server encounters an absolute + symlink during path traversal on behalf of a path-based operation, it must + terminate path traversal and return the symlink target. + +- Relative symlinks begin target resolution in the parent directory of the + symlink, so in theory most relative symlinks can be handled automatically + during the path traversal that encounters the symlink, provided that said + traversal is supplied with the number of remaining symlinks before `ELOOP`. + However, the new path traversed by the symlink target may cross VFS mount + boundaries, such that it's only safe for remote filesystem servers to + speculatively follow relative symlinks for side-effect-free operations such + as `stat` (where the sentry can simply ignore results that are inapplicable + due to crossing mount boundaries). We may choose to delay implementation of + this feature, at the cost of an additional RPC per relative symlink (note + that even if the symlink target crosses a mount boundary, the sentry will + need to `stat` the path to the mount boundary to confirm that each traversed + component is an accessible directory); until it is implemented, relative + symlinks may be handled like absolute symlinks, by terminating path + traversal and returning the symlink target. + +The possibility of symlinks (and the possibility of a compromised sentry) means +that the sentry may issue RPCs with paths that, in the absence of symlinks, +would traverse beyond the root of the remote filesystem. For example, the sentry +may issue an RPC with a path like "/foo/../..", on the premise that if "/foo" is +a symlink then the resulting path may be elsewhere on the remote filesystem. To +handle this, path traversal must also track its current depth below the remote +filesystem root, and terminate path traversal if it would ascend beyond this +point. + +### Path Traversal + +Since path-based VFS operations will translate to path-based RPCs, filesystem +servers will need to handle path traversal. From the perspective of a given +filesystem implementation in the server, there are two basic approaches to path +traversal: + +- Inode-walk: For each path component, obtain a handle to the underlying + filesystem object (e.g. with `open(O_PATH)`), check if that object is a + symlink (as described above) and that that object is accessible by the + caller (e.g. with `fstat()`), then continue to the next path component (e.g. + with `openat()`). This ensures that the checked filesystem object is the one + used to obtain the next object in the traversal, which is intuitively + appealing. However, while this approach works for host local filesystems, it + requires features that are not widely supported by other filesystems. + +- Path-walk: For each path component, use a path-based operation to determine + if the filesystem object currently referred to by that path component is a + symlink / is accessible. This is highly portable, but suffers from quadratic + behavior (at the level of the underlying filesystem implementation, the + first path component will be traversed a number of times equal to the number + of path components in the path). + +The implementation should support either option by delegating path traversal to +filesystem implementations within the server (like VFS and the remote filesystem +protocol itself), as inode-walking is still safe, efficient, amenable to FD +caching, and implementable on non-shared host local filesystems (a sufficiently +common case as to be worth considering in the design). + +Both approaches are susceptible to race conditions that may permit sandboxed +filesystem escapes: + +- Under inode-walk, a malicious application may cause a directory to be moved + (with `rename`) during path traversal, such that the filesystem + implementation incorrectly determines whether subsequent inodes are located + in paths that should be visible to sandboxed applications. + +- Under path-walk, a malicious application may cause a non-symlink file to be + replaced with a symlink during path traversal, such that following path + operations will incorrectly follow the symlink. + +Both race conditions can, to some extent, be mitigated in filesystem server +implementations by synchronizing path traversal with the hazardous operations in +question. However, shared filesystems are frequently used to share data between +sandboxed and unsandboxed applications in a controlled way, and in some cases a +malicious sandboxed application may be able to take advantage of a hazardous +filesystem operation performed by an unsandboxed application. In some cases, +filesystem features may be available to ensure safety even in such cases (e.g. +[the new openat2() syscall](https://man7.org/linux/man-pages/man2/openat2.2.html)), +but it is not clear how to solve this problem in general. (Note that this issue +is not specific to our design; rather, it is a fundamental limitation of +filesystem sandboxing.) + +### Filesystem Multiplexing + +A given sentry may need to access multiple distinct remote filesystems (e.g. +different volumes for a given container). In many cases, there is no advantage +to serving these filesystems from distinct filesystem servers, or accessing them +through distinct connections (factors such as maximum RPC concurrency should be +based on available host resources). Therefore, the protocol should support +multiplexing of distinct filesystem trees within a single session. 9P supports +this by allowing multiple calls to the `attach` RPC to produce fids representing +distinct filesystem trees, but this is somewhat clunky; we propose a much +simpler mechanism wherein each message that conveys a path also conveys a +numeric filesystem ID that identifies a filesystem tree. + +## Alternatives Considered + +### Additional Extensions to 9P + +There are at least three conceptual aspects to 9P: + +- Wire format: messages with a 4-byte little-endian size prefix, strings with + a 2-byte little-endian size prefix, etc. Whether the wire format is worth + retaining is unclear; in particular, it's unclear that the 9P wire format + has a significant advantage over protobufs, which are substantially easier + to extend. Note that the official Go protobuf implementation is widely known + to suffer from a significant number of performance deficiencies, so if we + choose to switch to protobuf, we may need to use an alternative toolchain + such as `gogo/protobuf` (which is also widely used in the Go ecosystem, e.g. + by Kubernetes). + +- Filesystem model: fids, qids, etc. Discarding this is one of the motivations + for this proposal. + +- RPCs: Twalk, Tlopen, etc. In addition to previously-described + inefficiencies, most of these are dependent on the filesystem model and + therefore must be discarded. + +### FUSE + +The FUSE (Filesystem in Userspace) protocol is frequently used to provide +arbitrary userspace filesystem implementations to a host Linux kernel. +Unfortunately, FUSE is also inode-based, and therefore doesn't address any of +the problems we have with 9P. + +### virtio-fs + +virtio-fs is an ongoing project aimed at improving Linux VM filesystem +performance when accessing Linux host filesystems (vs. virtio-9p). In brief, it +is based on: + +- Using a FUSE client in the guest that communicates over virtio with a FUSE + server in the host. + +- Using DAX to map the host page cache into the guest. + +- Using a file metadata table in shared memory to avoid VM exits for metadata + updates. + +None of these improvements seem applicable to gVisor: + +- As explained above, FUSE is still inode-based, so it is still susceptible to + most of the problems we have with 9P. + +- Our use of host file descriptors already allows us to leverage the host page + cache for file contents. + +- Our need for shared filesystem coherence is usually based on a user + requirement that an out-of-sandbox filesystem mutation is guaranteed to be + visible by all subsequent observations from within the sandbox, or vice + versa; it's not clear that this can be guaranteed without a synchronous + signaling mechanism like an RPC. diff --git a/tools/go_marshal/marshal/BUILD b/pkg/marshal/BUILD index 4aec98218..4aec98218 100644 --- a/tools/go_marshal/marshal/BUILD +++ b/pkg/marshal/BUILD diff --git a/tools/go_marshal/marshal/marshal.go b/pkg/marshal/marshal.go index 85b196f08..d8cb44b40 100644 --- a/tools/go_marshal/marshal/marshal.go +++ b/pkg/marshal/marshal.go @@ -26,9 +26,10 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Task provides a subset of kernel.Task, used in marshalling. We don't import -// the kernel package directly to avoid circular dependency. -type Task interface { +// CopyContext defines the memory operations required to marshal to and from +// user memory. Typically, kernel.Task is used to provide implementations for +// these operations. +type CopyContext interface { // CopyScratchBuffer provides a task goroutine-local scratch buffer. See // kernel.CopyScratchBuffer. CopyScratchBuffer(size int) []byte @@ -107,7 +108,7 @@ type Marshallable interface { // If the copy-in from the task memory is only partially successful, CopyIn // should still attempt to deserialize as much data as possible. See comment // for UnmarshalBytes. - CopyIn(task Task, addr usermem.Addr) (int, error) + CopyIn(cc CopyContext, addr usermem.Addr) (int, error) // CopyOut serializes a Marshallable type to a task's memory. This may only // be called from a task goroutine. This is more efficient than calling @@ -118,7 +119,7 @@ type Marshallable interface { // The copy-out to the task memory may be partially successful, in which // case CopyOut returns how much data was serialized. See comment for // MarshalBytes for implications. - CopyOut(task Task, addr usermem.Addr) (int, error) + CopyOut(cc CopyContext, addr usermem.Addr) (int, error) // CopyOutN is like CopyOut, but explicitly requests a partial // copy-out. Note that this may yield unexpected results for non-packed @@ -126,7 +127,7 @@ type Marshallable interface { // comment on MarshalBytes. // // The limit must be less than or equal to SizeBytes(). - CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) + CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) } // go-marshal generates additional functions for a type based on additional @@ -156,10 +157,10 @@ type Marshallable interface { // func UnmarshalUnsafeFooSlice(dst []Foo, src []byte) (int, error) { ... } // // // CopyFooSliceIn copies in a slice of Foo objects from the task's memory. -// func CopyFooSliceIn(task marshal.Task, addr usermem.Addr, dst []Foo) (int, error) { ... } +// func CopyFooSliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Foo) (int, error) { ... } // // // CopyFooSliceIn copies out a slice of Foo objects to the task's memory. -// func CopyFooSliceOut(task marshal.Task, addr usermem.Addr, src []Foo) (int, error) { ... } +// func CopyFooSliceOut(cc marshal.CopyContext, addr usermem.Addr, src []Foo) (int, error) { ... } // // The name of the functions are of the format "Copy%sIn" and "Copy%sOut", where // %s is the first argument to the slice clause. This directive is not supported @@ -174,10 +175,10 @@ type Marshallable interface { // This is only valid on newtypes on primitives, and causes the generated // functions to accept slices of the inner type instead: // -// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []int32) (int, error) { ... } +// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []int32) (int, error) { ... } // // Without "inner", they would instead be: // -// func CopyInt32SliceIn(task marshal.Task, addr usermem.Addr, dst []Int32) (int, error) { ... } +// func CopyInt32SliceIn(cc marshal.CopyContext, addr usermem.Addr, dst []Int32) (int, error) { ... } // // This may help avoid a cast depending on how the generated functions are used. diff --git a/tools/go_marshal/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go index 89c7d3575..ea75e09f2 100644 --- a/tools/go_marshal/marshal/marshal_impl_util.go +++ b/pkg/marshal/marshal_impl_util.go @@ -44,7 +44,7 @@ func (StubMarshallable) MarshalBytes(dst []byte) { // UnmarshalBytes implements Marshallable.UnmarshalBytes. func (StubMarshallable) UnmarshalBytes(src []byte) { - panic("Please implement your own UnMarshalBytes function") + panic("Please implement your own UnmarshalBytes function") } // Packed implements Marshallable.Packed. @@ -63,16 +63,16 @@ func (StubMarshallable) UnmarshalUnsafe(src []byte) { } // CopyIn implements Marshallable.CopyIn. -func (StubMarshallable) CopyIn(task Task, addr usermem.Addr) (int, error) { +func (StubMarshallable) CopyIn(cc CopyContext, addr usermem.Addr) (int, error) { panic("Please implement your own CopyIn function") } // CopyOut implements Marshallable.CopyOut. -func (StubMarshallable) CopyOut(task Task, addr usermem.Addr) (int, error) { +func (StubMarshallable) CopyOut(cc CopyContext, addr usermem.Addr) (int, error) { panic("Please implement your own CopyOut function") } // CopyOutN implements Marshallable.CopyOutN. -func (StubMarshallable) CopyOutN(task Task, addr usermem.Addr, limit int) (int, error) { +func (StubMarshallable) CopyOutN(cc CopyContext, addr usermem.Addr, limit int) (int, error) { panic("Please implement your own CopyOutN function") } diff --git a/tools/go_marshal/primitive/BUILD b/pkg/marshal/primitive/BUILD index cc08ba63a..06741e6d1 100644 --- a/tools/go_marshal/primitive/BUILD +++ b/pkg/marshal/primitive/BUILD @@ -12,7 +12,7 @@ go_library( "//:sandbox", ], deps = [ + "//pkg/marshal", "//pkg/usermem", - "//tools/go_marshal/marshal", ], ) diff --git a/tools/go_marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index d93edda8b..dfdae5d60 100644 --- a/tools/go_marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -19,8 +19,8 @@ package primitive import ( "io" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // Int8 is a marshal.Marshallable implementation for int8. @@ -101,18 +101,18 @@ func (b *ByteSlice) UnmarshalUnsafe(src []byte) { } // CopyIn implements marshal.Marshallable.CopyIn. -func (b *ByteSlice) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - return task.CopyInBytes(addr, *b) +func (b *ByteSlice) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + return cc.CopyInBytes(addr, *b) } // CopyOut implements marshal.Marshallable.CopyOut. -func (b *ByteSlice) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { - return task.CopyOutBytes(addr, *b) +func (b *ByteSlice) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) { + return cc.CopyOutBytes(addr, *b) } // CopyOutN implements marshal.Marshallable.CopyOutN. -func (b *ByteSlice) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { - return task.CopyOutBytes(addr, (*b)[:limit]) +func (b *ByteSlice) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + return cc.CopyOutBytes(addr, (*b)[:limit]) } // WriteTo implements io.WriterTo.WriteTo. @@ -130,9 +130,9 @@ var _ marshal.Marshallable = (*ByteSlice)(nil) // CopyInt16In is a convenient wrapper for copying in an int16 from the task's // memory. -func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) { +func CopyInt16In(cc marshal.CopyContext, addr usermem.Addr, dst *int16) (int, error) { var buf Int16 - n, err := buf.CopyIn(task, addr) + n, err := buf.CopyIn(cc, addr) if err != nil { return n, err } @@ -142,16 +142,16 @@ func CopyInt16In(task marshal.Task, addr usermem.Addr, dst *int16) (int, error) // CopyInt16Out is a convenient wrapper for copying out an int16 to the task's // memory. -func CopyInt16Out(task marshal.Task, addr usermem.Addr, src int16) (int, error) { +func CopyInt16Out(cc marshal.CopyContext, addr usermem.Addr, src int16) (int, error) { srcP := Int16(src) - return srcP.CopyOut(task, addr) + return srcP.CopyOut(cc, addr) } // CopyUint16In is a convenient wrapper for copying in a uint16 from the task's // memory. -func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error) { +func CopyUint16In(cc marshal.CopyContext, addr usermem.Addr, dst *uint16) (int, error) { var buf Uint16 - n, err := buf.CopyIn(task, addr) + n, err := buf.CopyIn(cc, addr) if err != nil { return n, err } @@ -161,18 +161,18 @@ func CopyUint16In(task marshal.Task, addr usermem.Addr, dst *uint16) (int, error // CopyUint16Out is a convenient wrapper for copying out a uint16 to the task's // memory. -func CopyUint16Out(task marshal.Task, addr usermem.Addr, src uint16) (int, error) { +func CopyUint16Out(cc marshal.CopyContext, addr usermem.Addr, src uint16) (int, error) { srcP := Uint16(src) - return srcP.CopyOut(task, addr) + return srcP.CopyOut(cc, addr) } // 32-bit integers // CopyInt32In is a convenient wrapper for copying in an int32 from the task's // memory. -func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) { +func CopyInt32In(cc marshal.CopyContext, addr usermem.Addr, dst *int32) (int, error) { var buf Int32 - n, err := buf.CopyIn(task, addr) + n, err := buf.CopyIn(cc, addr) if err != nil { return n, err } @@ -182,16 +182,16 @@ func CopyInt32In(task marshal.Task, addr usermem.Addr, dst *int32) (int, error) // CopyInt32Out is a convenient wrapper for copying out an int32 to the task's // memory. -func CopyInt32Out(task marshal.Task, addr usermem.Addr, src int32) (int, error) { +func CopyInt32Out(cc marshal.CopyContext, addr usermem.Addr, src int32) (int, error) { srcP := Int32(src) - return srcP.CopyOut(task, addr) + return srcP.CopyOut(cc, addr) } // CopyUint32In is a convenient wrapper for copying in a uint32 from the task's // memory. -func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error) { +func CopyUint32In(cc marshal.CopyContext, addr usermem.Addr, dst *uint32) (int, error) { var buf Uint32 - n, err := buf.CopyIn(task, addr) + n, err := buf.CopyIn(cc, addr) if err != nil { return n, err } @@ -201,18 +201,18 @@ func CopyUint32In(task marshal.Task, addr usermem.Addr, dst *uint32) (int, error // CopyUint32Out is a convenient wrapper for copying out a uint32 to the task's // memory. -func CopyUint32Out(task marshal.Task, addr usermem.Addr, src uint32) (int, error) { +func CopyUint32Out(cc marshal.CopyContext, addr usermem.Addr, src uint32) (int, error) { srcP := Uint32(src) - return srcP.CopyOut(task, addr) + return srcP.CopyOut(cc, addr) } // 64-bit integers // CopyInt64In is a convenient wrapper for copying in an int64 from the task's // memory. -func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) { +func CopyInt64In(cc marshal.CopyContext, addr usermem.Addr, dst *int64) (int, error) { var buf Int64 - n, err := buf.CopyIn(task, addr) + n, err := buf.CopyIn(cc, addr) if err != nil { return n, err } @@ -222,16 +222,16 @@ func CopyInt64In(task marshal.Task, addr usermem.Addr, dst *int64) (int, error) // CopyInt64Out is a convenient wrapper for copying out an int64 to the task's // memory. -func CopyInt64Out(task marshal.Task, addr usermem.Addr, src int64) (int, error) { +func CopyInt64Out(cc marshal.CopyContext, addr usermem.Addr, src int64) (int, error) { srcP := Int64(src) - return srcP.CopyOut(task, addr) + return srcP.CopyOut(cc, addr) } // CopyUint64In is a convenient wrapper for copying in a uint64 from the task's // memory. -func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error) { +func CopyUint64In(cc marshal.CopyContext, addr usermem.Addr, dst *uint64) (int, error) { var buf Uint64 - n, err := buf.CopyIn(task, addr) + n, err := buf.CopyIn(cc, addr) if err != nil { return n, err } @@ -241,7 +241,7 @@ func CopyUint64In(task marshal.Task, addr usermem.Addr, dst *uint64) (int, error // CopyUint64Out is a convenient wrapper for copying out a uint64 to the task's // memory. -func CopyUint64Out(task marshal.Task, addr usermem.Addr, src uint64) (int, error) { +func CopyUint64Out(cc marshal.CopyContext, addr usermem.Addr, src uint64) (int, error) { srcP := Uint64(src) - return srcP.CopyOut(task, addr) + return srcP.CopyOut(cc, addr) } diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index 955c9c473..4b4f9bd52 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -29,6 +29,12 @@ const ( sha256DigestSize = 32 ) +// DigestSize returns the size (in bytes) of a digest. +// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). +func DigestSize() int { + return sha256DigestSize +} + // Layout defines the scale of a Merkle tree. type Layout struct { // blockSize is the size of a data block to be hashed. @@ -45,12 +51,25 @@ type Layout struct { // InitLayout initializes and returns a new Layout object describing the structure // of a tree. dataSize specifies the size of input data in bytes. -func InitLayout(dataSize int64) Layout { +func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { layout := Layout{ blockSize: usermem.PageSize, // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). digestSize: sha256DigestSize, } + + // treeStart is the offset (in bytes) of the first level of the tree in + // the file. If data and tree are in different files, treeStart should + // be zero. If data is in the same file as the tree, treeStart points + // to the block after the last data block (which may be zero-padded). + var treeStart int64 + if dataAndTreeInSameFile { + treeStart = dataSize + if dataSize%layout.blockSize != 0 { + treeStart += layout.blockSize - dataSize%layout.blockSize + } + } + numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize level := 0 offset := int64(0) @@ -60,14 +79,15 @@ func InitLayout(dataSize int64) Layout { // contain the hashes of the data blocks, while level numLevels - 1 is // the root. for numBlocks > 1 { - layout.levelOffset = append(layout.levelOffset, offset*layout.blockSize) + layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) // Round numBlocks up to fill up a block. numBlocks += (layout.hashesPerBlock() - numBlocks%layout.hashesPerBlock()) % layout.hashesPerBlock() offset += numBlocks / layout.hashesPerBlock() numBlocks = numBlocks / layout.hashesPerBlock() level++ } - layout.levelOffset = append(layout.levelOffset, offset*layout.blockSize) + layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) + return layout } @@ -107,11 +127,44 @@ func (layout Layout) blockOffset(level int, index int64) int64 { // written to treeWriter. The treeReader should be able to read the tree after // it has been written. That is, treeWriter and treeReader should point to the // same underlying data but have separate cursors. -func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter io.Writer) ([]byte, error) { - layout := InitLayout(dataSize) +// Generate will modify the cursor for data, but always restores it to its +// original position upon exit. The cursor for tree is modified and not +// restored. +func Generate(data io.ReadSeeker, dataSize int64, treeReader io.ReadSeeker, treeWriter io.WriteSeeker, dataAndTreeInSameFile bool) ([]byte, error) { + layout := InitLayout(dataSize, dataAndTreeInSameFile) numBlocks := (dataSize + layout.blockSize - 1) / layout.blockSize + // If the data is in the same file as the tree, zero pad the last data + // block. + bytesInLastBlock := dataSize % layout.blockSize + if dataAndTreeInSameFile && bytesInLastBlock != 0 { + zeroBuf := make([]byte, layout.blockSize-bytesInLastBlock) + if _, err := treeWriter.Seek(0, io.SeekEnd); err != nil && err != io.EOF { + return nil, err + } + if _, err := treeWriter.Write(zeroBuf); err != nil { + return nil, err + } + } + + // Store the current offset, so we can set it back once verification + // finishes. + origOffset, err := data.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + defer data.Seek(origOffset, io.SeekStart) + + // Read from the beginning of both data and treeReader. + if _, err := data.Seek(0, io.SeekStart); err != nil && err != io.EOF { + return nil, err + } + + if _, err := treeReader.Seek(0, io.SeekStart); err != nil && err != io.EOF { + return nil, err + } + var root []byte for level := 0; level < layout.numLevels(); level++ { for i := int64(0); i < numBlocks; i++ { @@ -172,11 +225,11 @@ func Generate(data io.Reader, dataSize int64, treeReader io.Reader, treeWriter i // Verify will modify the cursor for data, but always restores it to its // original position upon exit. The cursor for tree is modified and not // restored. -func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte) error { +func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset int64, readSize int64, expectedRoot []byte, dataAndTreeInSameFile bool) (int64, error) { if readSize <= 0 { - return fmt.Errorf("Unexpected read size: %d", readSize) + return 0, fmt.Errorf("Unexpected read size: %d", readSize) } - layout := InitLayout(int64(dataSize)) + layout := InitLayout(int64(dataSize), dataAndTreeInSameFile) // Calculate the index of blocks that includes the target range in input // data. @@ -187,29 +240,30 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in // finishes. origOffset, err := data.Seek(0, io.SeekCurrent) if err != nil { - return fmt.Errorf("Find current data offset failed: %v", err) + return 0, fmt.Errorf("Find current data offset failed: %v", err) } defer data.Seek(origOffset, io.SeekStart) // Move to the first block that contains target data. if _, err := data.Seek(firstDataBlock*layout.blockSize, io.SeekStart); err != nil { - return fmt.Errorf("Seek to datablock start failed: %v", err) + return 0, fmt.Errorf("Seek to datablock start failed: %v", err) } buf := make([]byte, layout.blockSize) var readErr error - bytesRead := 0 + total := int64(0) for i := firstDataBlock; i <= lastDataBlock; i++ { // Read a block that includes all or part of target range in // input data. - bytesRead, readErr = data.Read(buf) + bytesRead, err := data.Read(buf) + readErr = err // If at the end of input data and all previous blocks are // verified, return the verified input data and EOF. if readErr == io.EOF && bytesRead == 0 { break } if readErr != nil && readErr != io.EOF { - return fmt.Errorf("Read from data failed: %v", err) + return 0, fmt.Errorf("Read from data failed: %v", err) } // If this is the end of file, zero the remaining bytes in buf, // otherwise they are still from the previous block. @@ -221,7 +275,7 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in } } if err := verifyBlock(tree, layout, buf, i, expectedRoot); err != nil { - return err + return 0, err } // startOff is the beginning of the read range within the // current data block. Note that for all blocks other than the @@ -245,10 +299,14 @@ func Verify(w io.Writer, data, tree io.ReadSeeker, dataSize int64, readOffset in if endOff > int64(bytesRead) { endOff = int64(bytesRead) } - w.Write(buf[startOff:endOff]) + n, err := w.Write(buf[startOff:endOff]) + if err != nil { + return total, err + } + total += int64(n) } - return readErr + return total, readErr } // verifyBlock verifies a block against tree. index is the number of block in diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index 911f61df9..daaca759a 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -27,80 +27,58 @@ import ( func TestLayout(t *testing.T) { testCases := []struct { - dataSize int64 - expectedLevelOffset []int64 + dataSize int64 + dataAndTreeInSameFile bool + expectedLevelOffset []int64 }{ { - dataSize: 100, - expectedLevelOffset: []int64{0}, + dataSize: 100, + dataAndTreeInSameFile: false, + expectedLevelOffset: []int64{0}, }, { - dataSize: 1000000, - expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, + dataSize: 100, + dataAndTreeInSameFile: true, + expectedLevelOffset: []int64{usermem.PageSize}, }, { - dataSize: 4096 * int64(usermem.PageSize), - expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, + dataSize: 1000000, + dataAndTreeInSameFile: false, + expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, }, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - p := InitLayout(tc.dataSize) - if p.blockSize != int64(usermem.PageSize) { - t.Errorf("got blockSize %d, want %d", p.blockSize, usermem.PageSize) - } - if p.digestSize != sha256DigestSize { - t.Errorf("got digestSize %d, want %d", p.digestSize, sha256DigestSize) - } - if p.numLevels() != len(tc.expectedLevelOffset) { - t.Errorf("got levels %d, want %d", p.numLevels(), len(tc.expectedLevelOffset)) - } - for i := 0; i < p.numLevels() && i < len(tc.expectedLevelOffset); i++ { - if p.levelOffset[i] != tc.expectedLevelOffset[i] { - t.Errorf("got levelStart[%d] %d, want %d", i, p.levelOffset[i], tc.expectedLevelOffset[i]) - } - } - }) - } -} - -func TestGenerate(t *testing.T) { - // The input data has size dataSize. It starts with the data in startWith, - // and all other bytes are zeroes. - testCases := []struct { - data []byte - expectedRoot []byte - }{ { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167}, - }, - { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132}, + dataSize: 1000000, + dataAndTreeInSameFile: true, + expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, }, { - data: []byte{'a'}, - expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210}, + dataSize: 4096 * int64(usermem.PageSize), + dataAndTreeInSameFile: false, + expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90}, + dataSize: 4096 * int64(usermem.PageSize), + dataAndTreeInSameFile: true, + expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) { - var tree bytes.Buffer - - root, err := Generate(bytes.NewBuffer(tc.data), int64(len(tc.data)), &tree, &tree) - if err != nil { - t.Fatalf("Generate failed: %v", err) + t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { + l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) + if l.blockSize != int64(usermem.PageSize) { + t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } - - if !bytes.Equal(root, tc.expectedRoot) { - t.Errorf("Unexpected root") + if l.digestSize != sha256DigestSize { + t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) + } + if l.numLevels() != len(tc.expectedLevelOffset) { + t.Errorf("Got levels %d, want %d", l.numLevels(), len(tc.expectedLevelOffset)) + } + for i := 0; i < l.numLevels() && i < len(tc.expectedLevelOffset); i++ { + if l.levelOffset[i] != tc.expectedLevelOffset[i] { + t.Errorf("Got levelStart[%d] %d, want %d", i, l.levelOffset[i], tc.expectedLevelOffset[i]) + } } }) } @@ -151,6 +129,57 @@ func (brw *bytesReadWriter) Seek(offset int64, whence int) (int64, error) { return off, nil } +func TestGenerate(t *testing.T) { + // The input data has size dataSize. It starts with the data in startWith, + // and all other bytes are zeroes. + testCases := []struct { + data []byte + expectedRoot []byte + }{ + { + data: bytes.Repeat([]byte{0}, usermem.PageSize), + expectedRoot: []byte{173, 127, 172, 178, 88, 111, 198, 233, 102, 192, 4, 215, 209, 209, 107, 2, 79, 88, 5, 255, 124, 180, 124, 122, 133, 218, 189, 139, 72, 137, 44, 167}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + expectedRoot: []byte{62, 93, 40, 92, 161, 241, 30, 223, 202, 99, 39, 2, 132, 113, 240, 139, 117, 99, 79, 243, 54, 18, 100, 184, 141, 121, 238, 46, 149, 202, 203, 132}, + }, + { + data: []byte{'a'}, + expectedRoot: []byte{52, 75, 204, 142, 172, 129, 37, 14, 145, 137, 103, 203, 11, 162, 209, 205, 30, 169, 213, 72, 20, 28, 243, 24, 242, 2, 92, 43, 169, 59, 110, 210}, + }, + { + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + expectedRoot: []byte{201, 62, 238, 45, 13, 176, 47, 16, 172, 199, 70, 13, 149, 118, 225, 34, 220, 248, 205, 83, 196, 191, 141, 252, 174, 27, 62, 116, 235, 207, 255, 90}, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%d:%v", len(tc.data), tc.data[0]), func(t *testing.T) { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + var root []byte + var err error + if dataAndTreeInSameFile { + tree.Write(tc.data) + root, err = Generate(&tree, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) + } else { + root, err = Generate(&bytesReadWriter{ + bytes: tc.data, + }, int64(len(tc.data)), &tree, &tree, dataAndTreeInSameFile) + } + if err != nil { + t.Fatalf("Got err: %v, want nil", err) + } + + if !bytes.Equal(root, tc.expectedRoot) { + t.Errorf("Got root: %v, want %v", root, tc.expectedRoot) + } + } + }) + } +} + func TestVerify(t *testing.T) { // The input data has size dataSize. The portion to be verified ranges from // verifyStart with verifySize. A bit is flipped in outOfRangeByteIndex to @@ -284,26 +313,44 @@ func TestVerify(t *testing.T) { data := make([]byte, tc.dataSize) // Generate random bytes in data. rand.Read(data) - var tree bytesReadWriter - - root, err := Generate(bytes.NewBuffer(data), int64(tc.dataSize), &tree, &tree) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Flip a bit in data and checks Verify results. - var buf bytes.Buffer - data[tc.modifyByte] ^= 1 - if tc.shouldSucceed { - if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root); err != nil && err != io.EOF { - t.Errorf("Verification failed when expected to succeed: %v", err) + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + var root []byte + var err error + if dataAndTreeInSameFile { + tree.Write(data) + root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile) + } else { + root, err = Generate(&bytesReadWriter{ + bytes: data, + }, int64(tc.dataSize), &tree, &tree, false /* dataAndTreeInSameFile */) } - if int64(buf.Len()) != tc.verifySize || !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output from Verify") + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } else { - if err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root); err == nil { - t.Errorf("Verification succeeded when expected to fail") + + // Flip a bit in data and checks Verify results. + var buf bytes.Buffer + data[tc.modifyByte] ^= 1 + if tc.shouldSucceed { + n, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + } else { + if _, err := Verify(&buf, bytes.NewReader(data), &tree, tc.dataSize, tc.verifyStart, tc.verifySize, root, dataAndTreeInSameFile); err == nil { + t.Errorf("Verification succeeded when expected to fail") + } } } }) @@ -318,36 +365,54 @@ func TestVerifyRandom(t *testing.T) { data := make([]byte, dataSize) // Generate random bytes in data. rand.Read(data) - var tree bytesReadWriter - root, err := Generate(bytes.NewBuffer(data), int64(dataSize), &tree, &tree) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + var root []byte + var err error + if dataAndTreeInSameFile { + tree.Write(data) + root, err = Generate(&tree, int64(len(data)), &tree, &tree, dataAndTreeInSameFile) + } else { + root, err = Generate(&bytesReadWriter{ + bytes: data, + }, int64(dataSize), &tree, &tree, dataAndTreeInSameFile) + } + if err != nil { + t.Fatalf("Generate failed: %v", err) + } - // Pick a random portion of data. - start := rand.Int63n(dataSize - 1) - size := rand.Int63n(dataSize) + 1 + // Pick a random portion of data. + start := rand.Int63n(dataSize - 1) + size := rand.Int63n(dataSize) + 1 - var buf bytes.Buffer - // Checks that the random portion of data from the original data is - // verified successfully. - if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root); err != nil && err != io.EOF { - t.Errorf("Verification failed for correct data: %v", err) - } - if size > dataSize-start { - size = dataSize - start - } - if int64(buf.Len()) != size || !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output from Verify") - } + var buf bytes.Buffer + // Checks that the random portion of data from the original data is + // verified successfully. + n, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile) + if err != nil && err != io.EOF { + t.Errorf("Verification failed for correct data: %v", err) + } + if size > dataSize-start { + size = dataSize - start + } + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } - buf.Reset() - // Flip a random bit in randPortion, and check that verification fails. - randBytePos := rand.Int63n(size) - data[start+randBytePos] ^= 1 + buf.Reset() + // Flip a random bit in randPortion, and check that verification fails. + randBytePos := rand.Int63n(size) + data[start+randBytePos] ^= 1 - if err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root); err == nil { - t.Errorf("Verification succeeded for modified data") + if _, err := Verify(&buf, bytes.NewReader(data), &tree, dataSize, start, size, root, dataAndTreeInSameFile); err == nil { + t.Errorf("Verification succeeded for modified data") + } } } diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index 64aa365ce..d012c5734 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -106,8 +106,8 @@ type customUint64Metric struct { // after Initialized. // // Preconditions: -// * name must be globally unique. -// * Initialize/Disable have not been called. +// * name must be globally unique. +// * Initialize/Disable have not been called. func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.MetricMetadata_Units, description string, value func() uint64) error { if initialized { return ErrInitializationDone @@ -221,7 +221,7 @@ var ( // EmitMetricUpdate is thread-safe. // // Preconditions: -// * Initialize has been called. +// * Initialize has been called. func EmitMetricUpdate() { emitMu.Lock() defer emitMu.Unlock() diff --git a/pkg/p9/client_file.go b/pkg/p9/client_file.go index 2ee07b664..28fe081d6 100644 --- a/pkg/p9/client_file.go +++ b/pkg/p9/client_file.go @@ -54,6 +54,8 @@ func (c *Client) newFile(fid FID) *clientFile { // // This proxies all of the interfaces found in file.go. type clientFile struct { + DisallowServerCalls + // client is the originating client. client *Client @@ -283,6 +285,39 @@ func (c *clientFile) Close() error { return nil } +// SetAttrClose implements File.SetAttrClose. +func (c *clientFile) SetAttrClose(valid SetAttrMask, attr SetAttr) error { + if !versionSupportsTsetattrclunk(c.client.version) { + setAttrErr := c.SetAttr(valid, attr) + + // Try to close file even in case of failure above. Since the state of the + // file is unknown to the caller, it will not attempt to close the file + // again. + if err := c.Close(); err != nil { + return err + } + + return setAttrErr + } + + // Avoid double close. + if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + return syscall.EBADF + } + + // Send the message. + if err := c.client.sendRecv(&Tsetattrclunk{FID: c.fid, Valid: valid, SetAttr: attr}, &Rsetattrclunk{}); err != nil { + // If an error occurred, we toss away the FID. This isn't ideal, + // but I'm not sure what else makes sense in this context. + log.Warningf("Tsetattrclunk failed, losing FID %v: %v", c.fid, err) + return err + } + + // Return the FID to the pool. + c.client.fidPool.Put(uint64(c.fid)) + return nil +} + // Open implements File.Open. func (c *clientFile) Open(flags OpenFlags) (*fd.FD, QID, uint32, error) { if atomic.LoadUint32(&c.closed) != 0 { @@ -681,6 +716,3 @@ func (c *clientFile) Flush() error { return c.client.sendRecv(&Tflushf{FID: c.fid}, &Rflushf{}) } - -// Renamed implements File.Renamed. -func (c *clientFile) Renamed(newDir File, newName string) {} diff --git a/pkg/p9/file.go b/pkg/p9/file.go index cab35896f..c2e3a3f98 100644 --- a/pkg/p9/file.go +++ b/pkg/p9/file.go @@ -135,6 +135,14 @@ type File interface { // On the server, Close has no concurrency guarantee. Close() error + // SetAttrClose is the equivalent of calling SetAttr() followed by Close(). + // This can be used to set file times before closing the file in a single + // operation. + // + // On the server, SetAttr has a write concurrency guarantee. + // On the server, Close has no concurrency guarantee. + SetAttrClose(valid SetAttrMask, attr SetAttr) error + // Open must be called prior to using Read, Write or Readdir. Once Open // is called, some operations, such as Walk, will no longer work. // @@ -286,3 +294,19 @@ type DefaultWalkGetAttr struct{} func (DefaultWalkGetAttr) WalkGetAttr([]string) ([]QID, File, AttrMask, Attr, error) { return nil, nil, AttrMask{}, Attr{}, syscall.ENOSYS } + +// DisallowClientCalls panics if a client-only function is called. +type DisallowClientCalls struct{} + +// SetAttrClose implements File.SetAttrClose. +func (DisallowClientCalls) SetAttrClose(SetAttrMask, SetAttr) error { + panic("SetAttrClose should not be called on the server") +} + +// DisallowServerCalls panics if a server-only function is called. +type DisallowServerCalls struct{} + +// Renamed implements File.Renamed. +func (*clientFile) Renamed(File, string) { + panic("Renamed should not be called on the client") +} diff --git a/pkg/p9/handlers.go b/pkg/p9/handlers.go index 1db5797dd..abd237f46 100644 --- a/pkg/p9/handlers.go +++ b/pkg/p9/handlers.go @@ -123,6 +123,37 @@ func (t *Tclunk) handle(cs *connState) message { return &Rclunk{} } +func (t *Tsetattrclunk) handle(cs *connState) message { + ref, ok := cs.LookupFID(t.FID) + if !ok { + return newErr(syscall.EBADF) + } + defer ref.DecRef() + + setAttrErr := ref.safelyWrite(func() error { + // We don't allow setattr on files that have been deleted. + // This might be technically incorrect, as it's possible that + // there were multiple links and you can still change the + // corresponding inode information. + if ref.isDeleted() { + return syscall.EINVAL + } + + // Set the attributes. + return ref.file.SetAttr(t.Valid, t.SetAttr) + }) + + // Try to delete FID even in case of failure above. Since the state of the + // file is unknown to the caller, it will not attempt to close the file again. + if !cs.DeleteFID(t.FID) { + return newErr(syscall.EBADF) + } + if setAttrErr != nil { + return newErr(setAttrErr) + } + return &Rsetattrclunk{} +} + // handle implements handler.handle. func (t *Tremove) handle(cs *connState) message { ref, ok := cs.LookupFID(t.FID) diff --git a/pkg/p9/messages.go b/pkg/p9/messages.go index 2cb59f934..cf13cbb69 100644 --- a/pkg/p9/messages.go +++ b/pkg/p9/messages.go @@ -317,6 +317,64 @@ func (r *Rclunk) String() string { return "Rclunk{}" } +// Tsetattrclunk is a setattr+close request. +type Tsetattrclunk struct { + // FID is the FID to change. + FID FID + + // Valid is the set of bits which will be used. + Valid SetAttrMask + + // SetAttr is the set request. + SetAttr SetAttr +} + +// decode implements encoder.decode. +func (t *Tsetattrclunk) decode(b *buffer) { + t.FID = b.ReadFID() + t.Valid.decode(b) + t.SetAttr.decode(b) +} + +// encode implements encoder.encode. +func (t *Tsetattrclunk) encode(b *buffer) { + b.WriteFID(t.FID) + t.Valid.encode(b) + t.SetAttr.encode(b) +} + +// Type implements message.Type. +func (*Tsetattrclunk) Type() MsgType { + return MsgTsetattrclunk +} + +// String implements fmt.Stringer. +func (t *Tsetattrclunk) String() string { + return fmt.Sprintf("Tsetattrclunk{FID: %d, Valid: %v, SetAttr: %s}", t.FID, t.Valid, t.SetAttr) +} + +// Rsetattrclunk is a setattr+close response. +type Rsetattrclunk struct { +} + +// decode implements encoder.decode. +func (*Rsetattrclunk) decode(*buffer) { +} + +// encode implements encoder.encode. +func (*Rsetattrclunk) encode(*buffer) { +} + +// Type implements message.Type. +func (*Rsetattrclunk) Type() MsgType { + return MsgRsetattrclunk +} + +// String implements fmt.Stringer. +func (r *Rsetattrclunk) String() string { + return "Rsetattrclunk{}" +} + // Tremove is a remove request. // // This will eventually be replaced by Tunlinkat. @@ -2657,6 +2715,8 @@ func init() { msgRegistry.register(MsgRlconnect, func() message { return &Rlconnect{} }) msgRegistry.register(MsgTallocate, func() message { return &Tallocate{} }) msgRegistry.register(MsgRallocate, func() message { return &Rallocate{} }) + msgRegistry.register(MsgTsetattrclunk, func() message { return &Tsetattrclunk{} }) + msgRegistry.register(MsgRsetattrclunk, func() message { return &Rsetattrclunk{} }) msgRegistry.register(MsgTchannel, func() message { return &Tchannel{} }) msgRegistry.register(MsgRchannel, func() message { return &Rchannel{} }) } diff --git a/pkg/p9/messages_test.go b/pkg/p9/messages_test.go index 7facc9f5e..bfeb6c236 100644 --- a/pkg/p9/messages_test.go +++ b/pkg/p9/messages_test.go @@ -376,6 +376,30 @@ func TestEncodeDecode(t *testing.T) { &Rumknod{ Rmknod{QID: QID{Type: 1}}, }, + &Tsetattrclunk{ + FID: 1, + Valid: SetAttrMask{ + Permissions: true, + UID: true, + GID: true, + Size: true, + ATime: true, + MTime: true, + CTime: true, + ATimeNotSystemTime: true, + MTimeNotSystemTime: true, + }, + SetAttr: SetAttr{ + Permissions: 1, + UID: 2, + GID: 3, + Size: 4, + ATimeSeconds: 5, + ATimeNanoSeconds: 6, + MTimeSeconds: 7, + MTimeNanoSeconds: 8, + }, + }, } for _, enc := range objs { diff --git a/pkg/p9/p9.go b/pkg/p9/p9.go index 122c457d2..2235f8968 100644 --- a/pkg/p9/p9.go +++ b/pkg/p9/p9.go @@ -315,86 +315,88 @@ type MsgType uint8 // MsgType declarations. const ( - MsgTlerror MsgType = 6 - MsgRlerror = 7 - MsgTstatfs = 8 - MsgRstatfs = 9 - MsgTlopen = 12 - MsgRlopen = 13 - MsgTlcreate = 14 - MsgRlcreate = 15 - MsgTsymlink = 16 - MsgRsymlink = 17 - MsgTmknod = 18 - MsgRmknod = 19 - MsgTrename = 20 - MsgRrename = 21 - MsgTreadlink = 22 - MsgRreadlink = 23 - MsgTgetattr = 24 - MsgRgetattr = 25 - MsgTsetattr = 26 - MsgRsetattr = 27 - MsgTlistxattr = 28 - MsgRlistxattr = 29 - MsgTxattrwalk = 30 - MsgRxattrwalk = 31 - MsgTxattrcreate = 32 - MsgRxattrcreate = 33 - MsgTgetxattr = 34 - MsgRgetxattr = 35 - MsgTsetxattr = 36 - MsgRsetxattr = 37 - MsgTremovexattr = 38 - MsgRremovexattr = 39 - MsgTreaddir = 40 - MsgRreaddir = 41 - MsgTfsync = 50 - MsgRfsync = 51 - MsgTlink = 70 - MsgRlink = 71 - MsgTmkdir = 72 - MsgRmkdir = 73 - MsgTrenameat = 74 - MsgRrenameat = 75 - MsgTunlinkat = 76 - MsgRunlinkat = 77 - MsgTversion = 100 - MsgRversion = 101 - MsgTauth = 102 - MsgRauth = 103 - MsgTattach = 104 - MsgRattach = 105 - MsgTflush = 108 - MsgRflush = 109 - MsgTwalk = 110 - MsgRwalk = 111 - MsgTread = 116 - MsgRread = 117 - MsgTwrite = 118 - MsgRwrite = 119 - MsgTclunk = 120 - MsgRclunk = 121 - MsgTremove = 122 - MsgRremove = 123 - MsgTflushf = 124 - MsgRflushf = 125 - MsgTwalkgetattr = 126 - MsgRwalkgetattr = 127 - MsgTucreate = 128 - MsgRucreate = 129 - MsgTumkdir = 130 - MsgRumkdir = 131 - MsgTumknod = 132 - MsgRumknod = 133 - MsgTusymlink = 134 - MsgRusymlink = 135 - MsgTlconnect = 136 - MsgRlconnect = 137 - MsgTallocate = 138 - MsgRallocate = 139 - MsgTchannel = 250 - MsgRchannel = 251 + MsgTlerror MsgType = 6 + MsgRlerror MsgType = 7 + MsgTstatfs MsgType = 8 + MsgRstatfs MsgType = 9 + MsgTlopen MsgType = 12 + MsgRlopen MsgType = 13 + MsgTlcreate MsgType = 14 + MsgRlcreate MsgType = 15 + MsgTsymlink MsgType = 16 + MsgRsymlink MsgType = 17 + MsgTmknod MsgType = 18 + MsgRmknod MsgType = 19 + MsgTrename MsgType = 20 + MsgRrename MsgType = 21 + MsgTreadlink MsgType = 22 + MsgRreadlink MsgType = 23 + MsgTgetattr MsgType = 24 + MsgRgetattr MsgType = 25 + MsgTsetattr MsgType = 26 + MsgRsetattr MsgType = 27 + MsgTlistxattr MsgType = 28 + MsgRlistxattr MsgType = 29 + MsgTxattrwalk MsgType = 30 + MsgRxattrwalk MsgType = 31 + MsgTxattrcreate MsgType = 32 + MsgRxattrcreate MsgType = 33 + MsgTgetxattr MsgType = 34 + MsgRgetxattr MsgType = 35 + MsgTsetxattr MsgType = 36 + MsgRsetxattr MsgType = 37 + MsgTremovexattr MsgType = 38 + MsgRremovexattr MsgType = 39 + MsgTreaddir MsgType = 40 + MsgRreaddir MsgType = 41 + MsgTfsync MsgType = 50 + MsgRfsync MsgType = 51 + MsgTlink MsgType = 70 + MsgRlink MsgType = 71 + MsgTmkdir MsgType = 72 + MsgRmkdir MsgType = 73 + MsgTrenameat MsgType = 74 + MsgRrenameat MsgType = 75 + MsgTunlinkat MsgType = 76 + MsgRunlinkat MsgType = 77 + MsgTversion MsgType = 100 + MsgRversion MsgType = 101 + MsgTauth MsgType = 102 + MsgRauth MsgType = 103 + MsgTattach MsgType = 104 + MsgRattach MsgType = 105 + MsgTflush MsgType = 108 + MsgRflush MsgType = 109 + MsgTwalk MsgType = 110 + MsgRwalk MsgType = 111 + MsgTread MsgType = 116 + MsgRread MsgType = 117 + MsgTwrite MsgType = 118 + MsgRwrite MsgType = 119 + MsgTclunk MsgType = 120 + MsgRclunk MsgType = 121 + MsgTremove MsgType = 122 + MsgRremove MsgType = 123 + MsgTflushf MsgType = 124 + MsgRflushf MsgType = 125 + MsgTwalkgetattr MsgType = 126 + MsgRwalkgetattr MsgType = 127 + MsgTucreate MsgType = 128 + MsgRucreate MsgType = 129 + MsgTumkdir MsgType = 130 + MsgRumkdir MsgType = 131 + MsgTumknod MsgType = 132 + MsgRumknod MsgType = 133 + MsgTusymlink MsgType = 134 + MsgRusymlink MsgType = 135 + MsgTlconnect MsgType = 136 + MsgRlconnect MsgType = 137 + MsgTallocate MsgType = 138 + MsgRallocate MsgType = 139 + MsgTsetattrclunk MsgType = 140 + MsgRsetattrclunk MsgType = 141 + MsgTchannel MsgType = 250 + MsgRchannel MsgType = 251 ) // QIDType represents the file type for QIDs. diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go index 6e7bb3db2..6e605b14c 100644 --- a/pkg/p9/p9test/client_test.go +++ b/pkg/p9/p9test/client_test.go @@ -1225,22 +1225,31 @@ func TestOpen(t *testing.T) { func TestClose(t *testing.T) { type closeTest struct { name string - closeFn func(backend *Mock, f p9.File) + closeFn func(backend *Mock, f p9.File) error } cases := []closeTest{ { name: "close", - closeFn: func(_ *Mock, f p9.File) { - f.Close() + closeFn: func(_ *Mock, f p9.File) error { + return f.Close() }, }, { name: "remove", - closeFn: func(backend *Mock, f p9.File) { + closeFn: func(backend *Mock, f p9.File) error { // Allow the rename call in the parent, automatically translated. backend.parent.EXPECT().UnlinkAt(gomock.Any(), gomock.Any()).Times(1) - f.(deprecatedRemover).Remove() + return f.(deprecatedRemover).Remove() + }, + }, + { + name: "setAttrClose", + closeFn: func(backend *Mock, f p9.File) error { + valid := p9.SetAttrMask{ATime: true} + attr := p9.SetAttr{ATimeSeconds: 1, ATimeNanoSeconds: 2} + backend.EXPECT().SetAttr(valid, attr).Times(1) + return f.SetAttrClose(valid, attr) }, }, } @@ -1258,7 +1267,9 @@ func TestClose(t *testing.T) { _, backend, f := walkHelper(h, name, root) // Close via the prescribed method. - tc.closeFn(backend, f) + if err := tc.closeFn(backend, f); err != nil { + t.Fatalf("closeFn failed: %v", err) + } // Everything should fail with EBADF. if _, _, err := f.Walk(nil); err != syscall.EBADF { diff --git a/pkg/p9/server.go b/pkg/p9/server.go index 60cf94fa1..3736f12a3 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -60,12 +60,6 @@ type connState struct { // server is the backing server. server *Server - // sendMu is the send lock. - sendMu sync.Mutex - - // conn is the connection. - conn *unet.Socket - // fids is the set of active FIDs. // // This is used to find FIDs for files. @@ -87,16 +81,30 @@ type connState struct { // version 0 implies 9P2000.L. version uint32 + // pendingWg counts requests that are still being handled. + pendingWg sync.WaitGroup + // -- below relates to the legacy handler -- - // recvOkay indicates that a receive may start. - recvOkay chan bool + // recvMu serializes receiving from conn. + recvMu sync.Mutex + + // recvIdle is the number of goroutines in handleRequests() attempting to + // lock recvMu so that they can receive from conn. recvIdle is accessed + // using atomic memory operations. + recvIdle int32 - // recvDone is signalled when a message is received. - recvDone chan error + // If recvShutdown is true, at least one goroutine has observed a + // connection error while receiving from conn, and all goroutines in + // handleRequests() should exit immediately. recvShutdown is protected by + // recvMu. + recvShutdown bool - // sendDone is signalled when a send is finished. - sendDone chan error + // sendMu serializes sending to conn. + sendMu sync.Mutex + + // conn is the connection used by the legacy transport. + conn *unet.Socket // -- below relates to the flipcall handler -- @@ -479,7 +487,9 @@ func (cs *connState) lookupChannel(id uint32) *channel { // handle handles a single message. func (cs *connState) handle(m message) (r message) { + cs.pendingWg.Add(1) defer func() { + cs.pendingWg.Done() if r == nil { // Don't allow a panic to propagate. err := recover() @@ -503,11 +513,21 @@ func (cs *connState) handle(m message) (r message) { return } -// handleRequest handles a single request. -// -// The recvDone channel is signaled when recv is done (with a error if -// necessary). The sendDone channel is signaled with the result of the send. -func (cs *connState) handleRequest() { +// handleRequest handles a single request. It returns true if the caller should +// continue handling requests and false if it should terminate. +func (cs *connState) handleRequest() bool { + // Obtain the right to receive a message from cs.conn. + atomic.AddInt32(&cs.recvIdle, 1) + cs.recvMu.Lock() + atomic.AddInt32(&cs.recvIdle, -1) + + if cs.recvShutdown { + // Another goroutine already detected a connection problem; exit + // immediately. + cs.recvMu.Unlock() + return false + } + messageSize := atomic.LoadUint32(&cs.messageSize) if messageSize == 0 { // Default or not yet negotiated. @@ -518,12 +538,17 @@ func (cs *connState) handleRequest() { tag, m, err := recv(cs.conn, messageSize, msgRegistry.get) if errSocket, ok := err.(ErrSocket); ok { // Connection problem; stop serving. - cs.recvDone <- errSocket.error - return + log.Debugf("p9.recv: %v", errSocket.error) + cs.recvShutdown = true + cs.recvMu.Unlock() + return false } - // Signal receive is done. - cs.recvDone <- nil + // Ensure that another goroutine is available to receive from cs.conn. + if atomic.LoadInt32(&cs.recvIdle) == 0 { + go cs.handleRequests() // S/R-SAFE: Irrelevant. + } + cs.recvMu.Unlock() // Deal with other errors. if err != nil && err != io.EOF { @@ -532,16 +557,17 @@ func (cs *connState) handleRequest() { cs.sendMu.Lock() err := send(cs.conn, tag, newErr(err)) cs.sendMu.Unlock() - cs.sendDone <- err - return + if err != nil { + log.Debugf("p9.send: %v", err) + } + return true } // Try to start the tag. if !cs.StartTag(tag) { // Nothing we can do at this point; client is bogus. log.Debugf("no valid tag [%05d]", tag) - cs.sendDone <- ErrNoValidMessage - return + return true } // Handle the message. @@ -555,23 +581,29 @@ func (cs *connState) handleRequest() { cs.sendMu.Lock() err = send(cs.conn, tag, r) cs.sendMu.Unlock() - cs.sendDone <- err + if err != nil { + log.Debugf("p9.send: %v", err) + } // Return the message to the cache. msgRegistry.put(m) + + return true } func (cs *connState) handleRequests() { - for range cs.recvOkay { - cs.handleRequest() + for { + if !cs.handleRequest() { + return + } } } func (cs *connState) stop() { - // Close all channels. - close(cs.recvOkay) - close(cs.recvDone) - close(cs.sendDone) + // Wait for completion of all inflight requests. This is mostly so that if + // a request is stuck, the sandbox supervisor has the opportunity to kill + // us with SIGABRT to get a stack dump of the offending handler. + cs.pendingWg.Wait() // Free the channels. cs.channelMu.Lock() @@ -590,6 +622,9 @@ func (cs *connState) stop() { cs.channelAlloc.Destroy() } + // Ensure the connection is closed. + cs.conn.Close() + // Close all remaining fids. for fid, fidRef := range cs.fids { delete(cs.fids, fid) @@ -599,74 +634,23 @@ func (cs *connState) stop() { // handlers running via the wait for Pending => 0 below. fidRef.DecRef() } - - // Ensure the connection is closed. - cs.conn.Close() -} - -// service services requests concurrently. -func (cs *connState) service() error { - // Pending is the number of handlers that have finished receiving but - // not finished processing requests. These must be waiting on properly - // below. See the next comment for an explanation of the loop. - pending := 0 - - // Start the first request handler. - go cs.handleRequests() // S/R-SAFE: Irrelevant. - cs.recvOkay <- true - - // We loop and make sure there's always one goroutine waiting for a new - // request. We process all the data for a single request in one - // goroutine however, to ensure the best turnaround time possible. - for { - select { - case err := <-cs.recvDone: - if err != nil { - // Wait for pending handlers. - for i := 0; i < pending; i++ { - <-cs.sendDone - } - return nil - } - - // This handler is now pending. - pending++ - - // Kick the next receiver, or start a new handler - // if no receiver is currently waiting. - select { - case cs.recvOkay <- true: - default: - go cs.handleRequests() // S/R-SAFE: Irrelevant. - cs.recvOkay <- true - } - - case <-cs.sendDone: - // This handler is finished. - pending-- - - // Error sending a response? Nothing can be done. - // - // We don't terminate on a send error though, since - // we still have a pending receive. The error would - // have been logged above, we just ignore it here. - } - } } // Handle handles a single connection. func (s *Server) Handle(conn *unet.Socket) error { cs := &connState{ - server: s, - conn: conn, - fids: make(map[FID]*fidRef), - tags: make(map[Tag]chan struct{}), - recvOkay: make(chan bool), - recvDone: make(chan error, 10), - sendDone: make(chan error, 10), + server: s, + fids: make(map[FID]*fidRef), + tags: make(map[Tag]chan struct{}), + conn: conn, } defer cs.stop() - return cs.service() + + // Serve requests from conn in the current goroutine; handleRequests() will + // create more goroutines as needed. + cs.handleRequests() + + return nil } // Serve handles requests from the bound socket. diff --git a/pkg/p9/version.go b/pkg/p9/version.go index 09cde9f5a..8d7168ef5 100644 --- a/pkg/p9/version.go +++ b/pkg/p9/version.go @@ -26,7 +26,7 @@ const ( // // Clients are expected to start requesting this version number and // to continuously decrement it until a Tversion request succeeds. - highestSupportedVersion uint32 = 11 + highestSupportedVersion uint32 = 12 // lowestSupportedVersion is the lowest supported version X in a // version string of the format 9P2000.L.Google.X. @@ -173,3 +173,9 @@ func versionSupportsGetSetXattr(v uint32) bool { func versionSupportsListRemoveXattr(v uint32) bool { return v >= 11 } + +// versionSupportsTsetattrclunk returns true if version v supports +// the Tsetattrclunk message. +func versionSupportsTsetattrclunk(v uint32) bool { + return v >= 12 +} diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index 7c622e5d7..a45920040 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -14,7 +14,7 @@ // +build amd64 // +build go1.8 -// +build !go1.16 +// +build !go1.17 #include "textflag.h" diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index 48ebb5fd1..9d3b0666d 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -14,7 +14,7 @@ // +build arm64 // +build go1.8 -// +build !go1.16 +// +build !go1.17 #include "textflag.h" diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index d9d5e6bcb..57d8542b9 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -234,6 +234,39 @@ const ( LeaksLogTraces ) +// Set implements flag.Value. +func (l *LeakMode) Set(v string) error { + switch v { + case "disabled": + *l = NoLeakChecking + case "log-names": + *l = LeaksLogWarning + case "log-traces": + *l = LeaksLogTraces + default: + return fmt.Errorf("invalid ref leak mode %q", v) + } + return nil +} + +// Get implements flag.Value. +func (l *LeakMode) Get() interface{} { + return *l +} + +// String implements flag.Value. +func (l *LeakMode) String() string { + switch *l { + case NoLeakChecking: + return "disabled" + case LeaksLogWarning: + return "log-names" + case LeaksLogTraces: + return "log-traces" + } + panic(fmt.Sprintf("invalid ref leak mode %q", *l)) +} + // leakMode stores the current mode for the reference leak checker. // // Values must be one of the LeakMode values. diff --git a/pkg/refs_vfs2/BUILD b/pkg/refs_vfs2/BUILD index 7b3e10683..577b827a5 100644 --- a/pkg/refs_vfs2/BUILD +++ b/pkg/refs_vfs2/BUILD @@ -11,7 +11,7 @@ go_template( types = [ "T", ], - visibility = ["//pkg/sentry:internal"], + visibility = ["//:sandbox"], deps = [ "//pkg/log", "//pkg/refs", diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refs_vfs2/refs_template.go index 99c43c065..d9b552896 100644 --- a/pkg/refs_vfs2/refs_template.go +++ b/pkg/refs_vfs2/refs_template.go @@ -12,11 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package refs_template defines a template that can be used by reference counted -// objects. +// Package refs_template defines a template that can be used by reference +// counted objects. The "owner" template parameter is used in log messages to +// indicate the type of reference-counted object that exhibited a reference +// leak. As a result, structs that are embedded in other structs should not use +// this template, since it will make tracking down leaks more difficult. package refs_template import ( + "fmt" "runtime" "sync/atomic" @@ -38,6 +42,11 @@ var ownerType *T // Note that the number of references is actually refCount + 1 so that a default // zero-value Refs object contains one reference. // +// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in +// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. +// This will allow us to add stack trace information to the leak messages +// without growing the size of Refs. +// // +stateify savable type Refs struct { // refCount is composed of two fields: @@ -82,7 +91,7 @@ func (r *Refs) ReadRefs() int64 { //go:nosplit func (r *Refs) IncRef() { if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { - panic("Incrementing non-positive ref count") + panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType)) } } @@ -122,7 +131,7 @@ func (r *Refs) TryIncRef() bool { func (r *Refs) DecRef(destroy func()) { switch v := atomic.AddInt64(&r.refCount, -1); { case v < -1: - panic("Decrementing non-positive ref count") + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType)) case v == -1: // Call the destructor. diff --git a/pkg/safemem/BUILD b/pkg/safemem/BUILD index ce30382ab..68ed074f8 100644 --- a/pkg/safemem/BUILD +++ b/pkg/safemem/BUILD @@ -11,9 +11,7 @@ go_library( "seq_unsafe.go", ], visibility = ["//:sandbox"], - deps = [ - "//pkg/safecopy", - ], + deps = ["//pkg/safecopy"], ) go_test( diff --git a/pkg/safemem/seq_unsafe.go b/pkg/safemem/seq_unsafe.go index f5f0574f8..fc4049eeb 100644 --- a/pkg/safemem/seq_unsafe.go +++ b/pkg/safemem/seq_unsafe.go @@ -91,9 +91,10 @@ func BlockSeqFromSlice(slice []Block) BlockSeq { return blockSeqFromSliceLimited(slice, limit) } -// Preconditions: The combined length of all Blocks in slice <= limit. If -// len(slice) != 0, the first Block in slice has non-zero length, and limit > -// 0. +// Preconditions: +// * The combined length of all Blocks in slice <= limit. +// * If len(slice) != 0, the first Block in slice has non-zero length and +// limit > 0. func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq { switch len(slice) { case 0: diff --git a/pkg/seccomp/BUILD b/pkg/seccomp/BUILD index 29aeaab8c..bdef7762c 100644 --- a/pkg/seccomp/BUILD +++ b/pkg/seccomp/BUILD @@ -10,6 +10,7 @@ go_binary( "seccomp_test_victim_amd64.go", "seccomp_test_victim_arm64.go", ], + nogo = False, deps = [":seccomp"], ) diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index 55fd6967e..752e2dc32 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package seccomp provides basic seccomp filters for x86_64 (little endian). +// Package seccomp provides generation of basic seccomp filters. Currently, +// only little endian systems are supported. package seccomp import ( @@ -64,9 +65,9 @@ func Install(rules SyscallRules) error { Rules: rules, Action: linux.SECCOMP_RET_ALLOW, }, - }, defaultAction) + }, defaultAction, defaultAction) if log.IsLogging(log.Debug) { - programStr, errDecode := bpf.DecodeProgram(instrs) + programStr, errDecode := bpf.DecodeInstructions(instrs) if errDecode != nil { programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr) } @@ -117,7 +118,7 @@ var SyscallName = func(sysno uintptr) string { // BuildProgram builds a BPF program from the given map of actions to matching // SyscallRules. The single generated program covers all provided RuleSets. -func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFInstruction, error) { +func BuildProgram(rules []RuleSet, defaultAction, badArchAction linux.BPFAction) ([]linux.BPFInstruction, error) { program := bpf.NewProgramBuilder() // Be paranoid and check that syscall is done in the expected architecture. @@ -128,7 +129,7 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn // defaultLabel is at the bottom of the program. The size of program // may exceeds 255 lines, which is the limit of a condition jump. program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, skipOneInst, 0) - program.AddDirectJumpLabel(defaultLabel) + program.AddStmt(bpf.Ret|bpf.K, uint32(badArchAction)) if err := buildIndex(rules, program); err != nil { return nil, err } @@ -144,6 +145,11 @@ func BuildProgram(rules []RuleSet, defaultAction linux.BPFAction) ([]linux.BPFIn // buildIndex builds a BST to quickly search through all syscalls. func buildIndex(rules []RuleSet, program *bpf.ProgramBuilder) error { + // Do nothing if rules is empty. + if len(rules) == 0 { + return nil + } + // Build a list of all application system calls, across all given rule // sets. We have a simple BST, but may dispatch individual matchers // with different actions. The matchers are evaluated linearly. @@ -216,42 +222,163 @@ func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, action linux.BPFAc labelled := false for i, arg := range rule { if arg != nil { + // Break out early if using MatchAny since no further + // instructions are required. + if _, ok := arg.(MatchAny); ok { + continue + } + + // Determine the data offset for low and high bits of input. + dataOffsetLow := seccompDataOffsetArgLow(i) + dataOffsetHigh := seccompDataOffsetArgHigh(i) + if i == RuleIP { + dataOffsetLow = seccompDataOffsetIPLow + dataOffsetHigh = seccompDataOffsetIPHigh + } + + // Add the conditional operation. Input values to the BPF + // program are 64bit values. However, comparisons in BPF can + // only be done on 32bit values. This means that we need to do + // multiple BPF comparisons in order to do one logical 64bit + // comparison. switch a := arg.(type) { - case AllowAny: - case AllowValue: - dataOffsetLow := seccompDataOffsetArgLow(i) - dataOffsetHigh := seccompDataOffsetArgHigh(i) - if i == RuleIP { - dataOffsetLow = seccompDataOffsetIPLow - dataOffsetHigh = seccompDataOffsetIPHigh - } + case EqualTo: + // EqualTo checks that both the higher and lower 32bits are equal. high, low := uint32(a>>32), uint32(a) - // assert arg_low == low + + // Assert that the lower 32bits are equal. + // arg_low == low ? continue : violation p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) - // assert arg_high == high + + // Assert that the lower 32bits are also equal. + // arg_high == high ? continue/success : violation p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) labelled = true + case NotEqual: + // NotEqual checks that either the higher or lower 32bits + // are *not* equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("ne%v", i) + + // Check if the higher 32bits are (not) equal. + // arg_low == low ? continue : success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are not equal (assuming + // higher bits are equal). + // arg_high == high ? violation : continue/success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true case GreaterThan: - dataOffsetLow := seccompDataOffsetArgLow(i) - dataOffsetHigh := seccompDataOffsetArgHigh(i) - if i == RuleIP { - dataOffsetLow = seccompDataOffsetIPLow - dataOffsetHigh = seccompDataOffsetIPHigh - } - labelGood := fmt.Sprintf("gt%v", i) + // GreaterThan checks that the higher 32bits is greater + // *or* that the higher 32bits are equal and the lower + // 32bits are greater. high, low := uint32(a>>32), uint32(a) - // assert arg_high < high + labelGood := fmt.Sprintf("gt%v", i) + + // Assert the higher 32bits are greater than or equal. + // arg_high >= high ? continue : violation (arg_high < high) p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) - // arg_high > high + + // Assert that the lower 32bits are greater. + // arg_high == high ? continue : success (arg_high > high) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) - // arg_low < low + // arg_low > low ? continue/success : violation (arg_high == high and arg_low <= low) p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) p.AddJumpFalseLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) labelled = true + case GreaterThanOrEqual: + // GreaterThanOrEqual checks that the higher 32bits is + // greater *or* that the higher 32bits are equal and the + // lower 32bits are greater than or equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("ge%v", i) + + // Assert the higher 32bits are greater than or equal. + // arg_high >= high ? continue : violation (arg_high < high) + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + // arg_high == high ? continue : success (arg_high > high) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are greater (assuming the + // higher bits are equal). + // arg_low >= low ? continue/success : violation (arg_high == high and arg_low < low) + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jge|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case LessThan: + // LessThan checks that the higher 32bits is less *or* that + // the higher 32bits are equal and the lower 32bits are + // less. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("lt%v", i) + + // Assert the higher 32bits are less than or equal. + // arg_high > high ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + // arg_high == high ? continue : success (arg_high < high) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert that the lower 32bits are less (assuming the + // higher bits are equal). + // arg_low >= low ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jge|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case LessThanOrEqual: + // LessThan checks that the higher 32bits is less *or* that + // the higher 32bits are equal and the lower 32bits are + // less than or equal. + high, low := uint32(a>>32), uint32(a) + labelGood := fmt.Sprintf("le%v", i) + + // Assert the higher 32bits are less than or equal. + // assert arg_high > high ? violation : continue + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, high, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + // arg_high == high ? continue : success + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + + // Assert the lower bits are less than or equal (assuming + // the higher bits are equal). + // arg_low > low ? violation : success + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + p.AddJumpTrueLabel(bpf.Jmp|bpf.Jgt|bpf.K, low, ruleViolationLabel(ruleSetIdx, sysno, ruleidx), 0) + p.AddLabel(ruleLabel(ruleSetIdx, sysno, ruleidx, labelGood)) + labelled = true + case maskedEqual: + // MaskedEqual checks that the bitwise AND of the value and + // mask are equal for both the higher and lower 32bits. + high, low := uint32(a.value>>32), uint32(a.value) + maskHigh, maskLow := uint32(a.mask>>32), uint32(a.mask) + + // Assert that the lower 32bits are equal when masked. + // A <- arg_low. + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetLow) + // A <- arg_low & maskLow + p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskLow) + // Assert that arg_low & maskLow == low. + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + + // Assert that the higher 32bits are equal when masked. + // A <- arg_high + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, dataOffsetHigh) + // A <- arg_high & maskHigh + p.AddStmt(bpf.Alu|bpf.And|bpf.K, maskHigh) + // Assert that arg_high & maskHigh == high. + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(ruleSetIdx, sysno, ruleidx)) + labelled = true default: return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a)) } diff --git a/pkg/seccomp/seccomp_rules.go b/pkg/seccomp/seccomp_rules.go index a52dc1b4e..daf165bbf 100644 --- a/pkg/seccomp/seccomp_rules.go +++ b/pkg/seccomp/seccomp_rules.go @@ -39,28 +39,79 @@ func seccompDataOffsetArgHigh(i int) uint32 { return seccompDataOffsetArgLow(i) + 4 } -// AllowAny is marker to indicate any value will be accepted. -type AllowAny struct{} +// MatchAny is marker to indicate any value will be accepted. +type MatchAny struct{} -func (a AllowAny) String() (s string) { +func (a MatchAny) String() (s string) { return "*" } -// AllowValue specifies a value that needs to be strictly matched. -type AllowValue uintptr +// EqualTo specifies a value that needs to be strictly matched. +type EqualTo uintptr + +func (a EqualTo) String() (s string) { + return fmt.Sprintf("== %#x", uintptr(a)) +} + +// NotEqual specifies a value that is strictly not equal. +type NotEqual uintptr + +func (a NotEqual) String() (s string) { + return fmt.Sprintf("!= %#x", uintptr(a)) +} // GreaterThan specifies a value that needs to be strictly smaller. type GreaterThan uintptr -func (a AllowValue) String() (s string) { - return fmt.Sprintf("%#x ", uintptr(a)) +func (a GreaterThan) String() (s string) { + return fmt.Sprintf("> %#x", uintptr(a)) +} + +// GreaterThanOrEqual specifies a value that needs to be smaller or equal. +type GreaterThanOrEqual uintptr + +func (a GreaterThanOrEqual) String() (s string) { + return fmt.Sprintf(">= %#x", uintptr(a)) +} + +// LessThan specifies a value that needs to be strictly greater. +type LessThan uintptr + +func (a LessThan) String() (s string) { + return fmt.Sprintf("< %#x", uintptr(a)) +} + +// LessThanOrEqual specifies a value that needs to be greater or equal. +type LessThanOrEqual uintptr + +func (a LessThanOrEqual) String() (s string) { + return fmt.Sprintf("<= %#x", uintptr(a)) +} + +type maskedEqual struct { + mask uintptr + value uintptr +} + +func (a maskedEqual) String() (s string) { + return fmt.Sprintf("& %#x == %#x", a.mask, a.value) +} + +// MaskedEqual specifies a value that matches the input after the input is +// masked (bitwise &) against the given mask. Can be used to verify that input +// only includes certain approved flags. +func MaskedEqual(mask, value uintptr) interface{} { + return maskedEqual{ + mask: mask, + value: value, + } } // Rule stores the allowed syscall arguments. // // For example: // rule := Rule { -// AllowValue(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0 +// EqualTo(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0 // } type Rule [7]interface{} // 6 arguments + RIP @@ -89,12 +140,12 @@ func (r Rule) String() (s string) { // rules := SyscallRules{ // syscall.SYS_FUTEX: []Rule{ // { -// AllowAny{}, -// AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), +// MatchAny{}, +// EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), // }, // OR // { -// AllowAny{}, -// AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), +// MatchAny{}, +// EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), // }, // }, // syscall.SYS_GETPID: []Rule{}, diff --git a/pkg/seccomp/seccomp_test.go b/pkg/seccomp/seccomp_test.go index 5238df8bd..23f30678d 100644 --- a/pkg/seccomp/seccomp_test.go +++ b/pkg/seccomp/seccomp_test.go @@ -76,11 +76,14 @@ func TestBasic(t *testing.T) { } for _, test := range []struct { + name string ruleSets []RuleSet defaultAction linux.BPFAction + badArchAction linux.BPFAction specs []spec }{ { + name: "Single syscall", ruleSets: []RuleSet{ { Rules: SyscallRules{1: {}}, @@ -88,26 +91,28 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Single syscall allowed", + desc: "syscall allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Single syscall disallowed", + desc: "syscall disallowed", data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Multiple rulesets", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0x1), + EqualTo(0x1), }, }, }, @@ -122,30 +127,32 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_KILL_THREAD, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Multiple rulesets allowed (1a)", + desc: "allowed (1a)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple rulesets allowed (1b)", + desc: "allowed (1b)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple rulesets allowed (2)", + desc: "syscall 1 matched 2nd rule", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple rulesets allowed (2)", + desc: "no match", data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_KILL_THREAD, }, }, }, { + name: "Multiple syscalls", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -157,50 +164,52 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Multiple syscalls allowed (1)", + desc: "allowed (1)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls allowed (3)", + desc: "allowed (3)", data: seccompData{nr: 3, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls allowed (5)", + desc: "allowed (5)", data: seccompData{nr: 5, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Multiple syscalls disallowed (0)", + desc: "disallowed (0)", data: seccompData{nr: 0, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (2)", + desc: "disallowed (2)", data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (4)", + desc: "disallowed (4)", data: seccompData{nr: 4, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (6)", + desc: "disallowed (6)", data: seccompData{nr: 6, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Multiple syscalls disallowed (100)", + desc: "disallowed (100)", data: seccompData{nr: 100, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Wrong architecture", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -210,15 +219,17 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Wrong architecture", + desc: "arch (123)", data: seccompData{nr: 1, arch: 123}, - want: linux.SECCOMP_RET_TRAP, + want: linux.SECCOMP_RET_KILL_THREAD, }, }, }, { + name: "Syscall disallowed", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -228,22 +239,24 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall disallowed, action trap", + desc: "action trap", data: seccompData{nr: 2, arch: LINUX_AUDIT_ARCH}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Syscall arguments", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowAny{}, - AllowValue(0xf), + MatchAny{}, + EqualTo(0xf), }, }, }, @@ -251,29 +264,31 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall argument allowed", + desc: "allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Syscall argument disallowed", + desc: "disallowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xe}}, want: linux.SECCOMP_RET_TRAP, }, }, }, { + name: "Multiple arguments", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0xf), + EqualTo(0xf), }, { - AllowValue(0xe), + EqualTo(0xe), }, }, }, @@ -281,28 +296,30 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "Syscall argument allowed, two rules", + desc: "match first rule", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "Syscall argument allowed, two rules", + desc: "match 2nd rule", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xe}}, want: linux.SECCOMP_RET_ALLOW, }, }, }, { + name: "EqualTo", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - AllowValue(0), - AllowValue(math.MaxUint64 - 1), - AllowValue(math.MaxUint32), + EqualTo(0), + EqualTo(math.MaxUint64 - 1), + EqualTo(math.MaxUint32), }, }, }, @@ -310,9 +327,10 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "64bit syscall argument allowed", + desc: "argument allowed (all match)", data: seccompData{ nr: 1, arch: LINUX_AUDIT_ARCH, @@ -321,7 +339,7 @@ func TestBasic(t *testing.T) { want: linux.SECCOMP_RET_ALLOW, }, { - desc: "64bit syscall argument disallowed", + desc: "argument disallowed (one mismatch)", data: seccompData{ nr: 1, arch: LINUX_AUDIT_ARCH, @@ -330,7 +348,7 @@ func TestBasic(t *testing.T) { want: linux.SECCOMP_RET_TRAP, }, { - desc: "64bit syscall argument disallowed", + desc: "argument disallowed (multiple mismatch)", data: seccompData{ nr: 1, arch: LINUX_AUDIT_ARCH, @@ -341,6 +359,103 @@ func TestBasic(t *testing.T) { }, }, { + name: "NotEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + NotEqual(0x7aabbccdd), + NotEqual(math.MaxUint64 - 1), + NotEqual(math.MaxUint32), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + args: [6]uint64{0, math.MaxUint64, math.MaxUint32 - 1}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (one equal)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + args: [6]uint64{0x7aabbccdd, math.MaxUint64, math.MaxUint32 - 1}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (all equal)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + args: [6]uint64{0x7aabbccdd, math.MaxUint64 - 1, math.MaxUint32}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThan", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + GreaterThan(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThan (multi)", ruleSets: []RuleSet{ { Rules: SyscallRules{ @@ -355,46 +470,145 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "GreaterThan: Syscall argument allowed", + desc: "arg allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "GreaterThan: Syscall argument disallowed (equal)", + desc: "arg disallowed (first arg equal)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "Syscall argument disallowed (smaller)", + desc: "arg disallowed (first arg smaller)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "GreaterThan2: Syscall argument allowed", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xfbcd000d}}, + desc: "arg disallowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg smaller)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThanOrEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + GreaterThanOrEqual(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "GreaterThan2: Syscall argument disallowed (equal)", - data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "GreaterThanOrEqual (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + GreaterThanOrEqual(0xf), + GreaterThanOrEqual(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed (both greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (first arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0xf, 0xffffffff}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg smaller)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, want: linux.SECCOMP_RET_TRAP, }, { - desc: "GreaterThan2: Syscall argument disallowed (smaller)", + desc: "arg allowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xabcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (second arg smaller)", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x10, 0xa000ffff}}, want: linux.SECCOMP_RET_TRAP, }, + { + desc: "arg disallowed (both arg smaller)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xa000ffff}}, + want: linux.SECCOMP_RET_TRAP, + }, }, }, { + name: "LessThan", ruleSets: []RuleSet{ { Rules: SyscallRules{ 1: []Rule{ { - RuleIP: AllowValue(0x7aabbccdd), + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + LessThan(0x00000002_00000002), }, }, }, @@ -402,40 +616,307 @@ func TestBasic(t *testing.T) { }, }, defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, specs: []spec{ { - desc: "IP: Syscall instruction pointer allowed", + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + }, + }, + { + name: "LessThan (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + LessThan(0x1), + LessThan(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (first arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (both arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "LessThanOrEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // 4294967298 + // Both upper 32 bits and lower 32 bits are non-zero. + // 00000000000000000000000000000010 + // 00000000000000000000000000000010 + LessThanOrEqual(0x00000002_00000002), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "high 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000003_00000002}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits greater", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000003}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "high 32bits equal, low 32bits equal", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits equal, low 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000002_00000001}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "high 32bits less", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x00000001_00000002}}, + want: linux.SECCOMP_RET_ALLOW, + }, + }, + }, + + { + name: "LessThanOrEqual (multi)", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + LessThanOrEqual(0x1), + LessThanOrEqual(0xabcd000d), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (first arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x1, 0x0}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (first arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0x0}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg allowed (second arg equal)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xabcd000d}}, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (second arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x0, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (both arg greater)", + data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{0x2, 0xffffffff}}, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "MaskedEqual", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + // x & 00000001 00000011 (0x103) == 00000000 00000001 (0x1) + // Input x must have lowest order bit set and + // must *not* have 8th or second lowest order bit set. + MaskedEqual(0x103, 0x1), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "arg allowed (low order mandatory bit)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000001 + args: [6]uint64{0x1}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg allowed (low order optional bit)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000101 + args: [6]uint64{0x5}, + }, + want: linux.SECCOMP_RET_ALLOW, + }, + { + desc: "arg disallowed (lowest order bit not set)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000010 + args: [6]uint64{0x2}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (second lowest order bit set)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000000 00000011 + args: [6]uint64{0x3}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + { + desc: "arg disallowed (8th bit set)", + data: seccompData{ + nr: 1, + arch: LINUX_AUDIT_ARCH, + // 00000000 00000000 00000001 00000000 + args: [6]uint64{0x100}, + }, + want: linux.SECCOMP_RET_TRAP, + }, + }, + }, + { + name: "Instruction Pointer", + ruleSets: []RuleSet{ + { + Rules: SyscallRules{ + 1: []Rule{ + { + RuleIP: EqualTo(0x7aabbccdd), + }, + }, + }, + Action: linux.SECCOMP_RET_ALLOW, + }, + }, + defaultAction: linux.SECCOMP_RET_TRAP, + badArchAction: linux.SECCOMP_RET_KILL_THREAD, + specs: []spec{ + { + desc: "allowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x7aabbccdd}, want: linux.SECCOMP_RET_ALLOW, }, { - desc: "IP: Syscall instruction pointer disallowed", + desc: "disallowed", data: seccompData{nr: 1, arch: LINUX_AUDIT_ARCH, args: [6]uint64{}, instructionPointer: 0x711223344}, want: linux.SECCOMP_RET_TRAP, }, }, }, } { - instrs, err := BuildProgram(test.ruleSets, test.defaultAction) - if err != nil { - t.Errorf("%s: buildProgram() got error: %v", test.specs[0].desc, err) - continue - } - p, err := bpf.Compile(instrs) - if err != nil { - t.Errorf("%s: bpf.Compile() got error: %v", test.specs[0].desc, err) - continue - } - for _, spec := range test.specs { - got, err := bpf.Exec(p, spec.data.asInput()) + t.Run(test.name, func(t *testing.T) { + instrs, err := BuildProgram(test.ruleSets, test.defaultAction, test.badArchAction) if err != nil { - t.Errorf("%s: bpf.Exec() got error: %v", spec.desc, err) - continue + t.Fatalf("BuildProgram() got error: %v", err) + } + p, err := bpf.Compile(instrs) + if err != nil { + t.Fatalf("bpf.Compile() got error: %v", err) } - if got != uint32(spec.want) { - t.Errorf("%s: bpd.Exec() = %d, want: %d", spec.desc, got, spec.want) + for _, spec := range test.specs { + got, err := bpf.Exec(p, spec.data.asInput()) + if err != nil { + t.Fatalf("%s: bpf.Exec() got error: %v", spec.desc, err) + } + if got != uint32(spec.want) { + // Include a decoded version of the program in output for debugging purposes. + decoded, _ := bpf.DecodeInstructions(instrs) + t.Fatalf("%s: got: %d, want: %d\nBPF Program\n%s", spec.desc, got, spec.want, decoded) + } } - } + }) } } @@ -457,7 +938,7 @@ func TestRandom(t *testing.T) { Rules: syscallRules, Action: linux.SECCOMP_RET_ALLOW, }, - }, linux.SECCOMP_RET_TRAP) + }, linux.SECCOMP_RET_TRAP, linux.SECCOMP_RET_KILL_THREAD) if err != nil { t.Fatalf("buildProgram() got error: %v", err) } diff --git a/pkg/seccomp/seccomp_test_victim.go b/pkg/seccomp/seccomp_test_victim.go index fe157f539..7f33e0d9e 100644 --- a/pkg/seccomp/seccomp_test_victim.go +++ b/pkg/seccomp/seccomp_test_victim.go @@ -100,7 +100,7 @@ func main() { if !die { syscalls[syscall.SYS_OPENAT] = []seccomp.Rule{ { - seccomp.AllowValue(10), + seccomp.EqualTo(10), }, } } diff --git a/pkg/segment/set.go b/pkg/segment/set.go index 1a17ad9cb..fbb31dbea 100644 --- a/pkg/segment/set.go +++ b/pkg/segment/set.go @@ -407,7 +407,9 @@ func (s *Set) InsertWithoutMerging(gap GapIterator, r Range, val Value) Iterator // and returns an iterator to the inserted segment. All existing iterators // (including gap, but not including the returned iterator) are invalidated. // -// Preconditions: r.Start >= gap.Start(); r.End <= gap.End(). +// Preconditions: +// * r.Start >= gap.Start(). +// * r.End <= gap.End(). func (s *Set) InsertWithoutMergingUnchecked(gap GapIterator, r Range, val Value) Iterator { gap = gap.node.rebalanceBeforeInsert(gap) splitMaxGap := trackGaps != 0 && (gap.node.nrSegments == 0 || gap.Range().Length() == gap.node.maxGap.Get()) @@ -1211,12 +1213,10 @@ func (seg Iterator) End() Key { // does not invalidate any iterators. // // Preconditions: -// -// - r.Length() > 0. -// -// - The new range must not overlap an existing one: If seg.NextSegment().Ok(), -// then r.end <= seg.NextSegment().Start(); if seg.PrevSegment().Ok(), then -// r.start >= seg.PrevSegment().End(). +// * r.Length() > 0. +// * The new range must not overlap an existing one: +// * If seg.NextSegment().Ok(), then r.end <= seg.NextSegment().Start(). +// * If seg.PrevSegment().Ok(), then r.start >= seg.PrevSegment().End(). func (seg Iterator) SetRangeUnchecked(r Range) { seg.node.keys[seg.index] = r } @@ -1241,8 +1241,9 @@ func (seg Iterator) SetRange(r Range) { // SetStartUnchecked mutates the iterated segment's start. This operation does // not invalidate any iterators. // -// Preconditions: The new start must be valid: start < seg.End(); if -// seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). +// Preconditions: The new start must be valid: +// * start < seg.End() +// * If seg.PrevSegment().Ok(), then start >= seg.PrevSegment().End(). func (seg Iterator) SetStartUnchecked(start Key) { seg.node.keys[seg.index].Start = start } @@ -1264,8 +1265,9 @@ func (seg Iterator) SetStart(start Key) { // SetEndUnchecked mutates the iterated segment's end. This operation does not // invalidate any iterators. // -// Preconditions: The new end must be valid: end > seg.Start(); if -// seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). +// Preconditions: The new end must be valid: +// * end > seg.Start(). +// * If seg.NextSegment().Ok(), then end <= seg.NextSegment().Start(). func (seg Iterator) SetEndUnchecked(end Key) { seg.node.keys[seg.index].End = end } @@ -1695,9 +1697,11 @@ func (s *Set) ExportSortedSlices() *SegmentDataSlices { // ImportSortedSlice initializes the given set from the given slice. // -// Preconditions: s must be empty. sds must represent a valid set (the segments -// in sds must have valid lengths that do not overlap). The segments in sds -// must be sorted in ascending key order. +// Preconditions: +// * s must be empty. +// * sds must represent a valid set (the segments in sds must have valid +// lengths that do not overlap). +// * The segments in sds must be sorted in ascending key order. func (s *Set) ImportSortedSlices(sds *SegmentDataSlices) error { if !s.IsEmpty() { return fmt.Errorf("cannot import into non-empty set %v", s) diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index 901e0f320..99e2b3389 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -33,11 +33,12 @@ go_library( "//pkg/context", "//pkg/cpuid", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/limits", "//pkg/sync", "//pkg/syserror", "//pkg/usermem", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index a903d031c..d75d665ae 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -72,12 +73,12 @@ type Context interface { // with return values of varying sizes (for example ARCH_GETFS). This // is a simple utility function to convert to the native size in these // cases, and then we can CopyOut. - Native(val uintptr) interface{} + Native(val uintptr) marshal.Marshallable // Value converts a native type back to a generic value. // Once a value has been converted to native via the above call -- it // can be converted back here. - Value(val interface{}) uintptr + Value(val marshal.Marshallable) uintptr // Width returns the number of bytes for a native value. Width() uint @@ -205,7 +206,7 @@ type Context interface { // equivalent of arch_ptrace(): // PtracePeekUser implements ptrace(PTRACE_PEEKUSR). - PtracePeekUser(addr uintptr) (interface{}, error) + PtracePeekUser(addr uintptr) (marshal.Marshallable, error) // PtracePokeUser implements ptrace(PTRACE_POKEUSR). PtracePokeUser(addr, data uintptr) error diff --git a/pkg/sentry/arch/arch_amd64.go b/pkg/sentry/arch/arch_amd64.go index 1c3e3c14c..c7d3a206d 100644 --- a/pkg/sentry/arch/arch_amd64.go +++ b/pkg/sentry/arch/arch_amd64.go @@ -23,6 +23,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -179,14 +181,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) { } // Native returns the native type for the given val. -func (c *context64) Native(val uintptr) interface{} { - v := uint64(val) +func (c *context64) Native(val uintptr) marshal.Marshallable { + v := primitive.Uint64(val) return &v } // Value returns the generic val for the given native type. -func (c *context64) Value(val interface{}) uintptr { - return uintptr(*val.(*uint64)) +func (c *context64) Value(val marshal.Marshallable) uintptr { + return uintptr(*val.(*primitive.Uint64)) } // Width returns the byte width of this architecture. @@ -293,7 +295,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { const userStructSize = 928 // PtracePeekUser implements Context.PtracePeekUser. -func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { +func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) { if addr&7 != 0 || addr >= userStructSize { return nil, syscall.EIO } diff --git a/pkg/sentry/arch/arch_arm64.go b/pkg/sentry/arch/arch_arm64.go index 550741d8c..680d23a9f 100644 --- a/pkg/sentry/arch/arch_arm64.go +++ b/pkg/sentry/arch/arch_arm64.go @@ -22,6 +22,8 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/usermem" ) @@ -163,14 +165,14 @@ func (c *context64) SetOldRSeqInterruptedIP(value uintptr) { } // Native returns the native type for the given val. -func (c *context64) Native(val uintptr) interface{} { - v := uint64(val) +func (c *context64) Native(val uintptr) marshal.Marshallable { + v := primitive.Uint64(val) return &v } // Value returns the generic val for the given native type. -func (c *context64) Value(val interface{}) uintptr { - return uintptr(*val.(*uint64)) +func (c *context64) Value(val marshal.Marshallable) uintptr { + return uintptr(*val.(*primitive.Uint64)) } // Width returns the byte width of this architecture. @@ -274,7 +276,7 @@ func (c *context64) PIELoadAddress(l MmapLayout) usermem.Addr { } // PtracePeekUser implements Context.PtracePeekUser. -func (c *context64) PtracePeekUser(addr uintptr) (interface{}, error) { +func (c *context64) PtracePeekUser(addr uintptr) (marshal.Marshallable, error) { // TODO(gvisor.dev/issue/1239): Full ptrace supporting for Arm64. return c.Native(0), nil } diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go index 32173aa20..d3e2324a8 100644 --- a/pkg/sentry/arch/signal_act.go +++ b/pkg/sentry/arch/signal_act.go @@ -14,7 +14,7 @@ package arch -import "gvisor.dev/gvisor/tools/go_marshal/marshal" +import "gvisor.dev/gvisor/pkg/marshal" // Special values for SignalAct.Handler. const ( diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go index 0fa738a1d..a1eae98f9 100644 --- a/pkg/sentry/arch/signal_stack.go +++ b/pkg/sentry/arch/signal_stack.go @@ -17,8 +17,8 @@ package arch import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) const ( diff --git a/pkg/sentry/contexttest/contexttest.go b/pkg/sentry/contexttest/contexttest.go index 8e5658c7a..dfd195a23 100644 --- a/pkg/sentry/contexttest/contexttest.go +++ b/pkg/sentry/contexttest/contexttest.go @@ -144,27 +144,7 @@ func (t *TestContext) MemoryFile() *pgalloc.MemoryFile { // RootContext returns a Context that may be used in tests that need root // credentials. Uses ptrace as the platform.Platform. func RootContext(tb testing.TB) context.Context { - return WithCreds(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) -} - -// WithCreds returns a copy of ctx carrying creds. -func WithCreds(ctx context.Context, creds *auth.Credentials) context.Context { - return &authContext{ctx, creds} -} - -type authContext struct { - context.Context - creds *auth.Credentials -} - -// Value implements context.Context. -func (ac *authContext) Value(key interface{}) interface{} { - switch key { - case auth.CtxCredentials: - return ac.creds - default: - return ac.Context.Value(key) - } + return auth.ContextWithCredentials(Context(tb), auth.NewRootCredentials(auth.NewRootUserNamespace())) } // WithLimitSet returns a copy of ctx carrying l. diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index 2c5d14be5..deaf5fa23 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -35,7 +35,6 @@ go_library( "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", - "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/control/proc.go b/pkg/sentry/control/proc.go index dfa936563..668f47802 100644 --- a/pkg/sentry/control/proc.go +++ b/pkg/sentry/control/proc.go @@ -23,8 +23,8 @@ import ( "text/tabwriter" "time" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/fdimport" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" @@ -203,27 +203,17 @@ func (proc *Proc) execAsync(args *ExecArgs) (*kernel.ThreadGroup, kernel.ThreadI } initArgs.Filename = resolved - fds := make([]int, len(args.FilePayload.Files)) - for i, file := range args.FilePayload.Files { - if kernel.VFS2Enabled { - // Need to dup to remove ownership from os.File. - dup, err := unix.Dup(int(file.Fd())) - if err != nil { - return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err) - } - fds[i] = dup - } else { - // VFS1 dups the file on import. - fds[i] = int(file.Fd()) - } + fds, err := fd.NewFromFiles(args.Files) + if err != nil { + return nil, 0, nil, nil, fmt.Errorf("duplicating payload files: %w", err) } + defer func() { + for _, fd := range fds { + _ = fd.Close() + } + }() ttyFile, ttyFileVFS2, err := fdimport.Import(ctx, fdTable, args.StdioIsPty, fds) if err != nil { - if kernel.VFS2Enabled { - for _, fd := range fds { - unix.Close(fd) - } - } return nil, 0, nil, nil, err } diff --git a/pkg/sentry/device/device.go b/pkg/sentry/device/device.go index f45b2bd2b..6ca9dc79f 100644 --- a/pkg/sentry/device/device.go +++ b/pkg/sentry/device/device.go @@ -256,7 +256,7 @@ func (m *MultiDevice) Load(key MultiDeviceKey, value uint64) bool { } if k, exists := m.rcache[value]; exists && k != key { // Should never happen. - panic("MultiDevice's caches are inconsistent") + panic(fmt.Sprintf("MultiDevice's caches are inconsistent, current: %+v, previous: %+v", key, k)) } // Cache value at key. diff --git a/pkg/sentry/devices/memdev/BUILD b/pkg/sentry/devices/memdev/BUILD index abe58f818..4c8604d58 100644 --- a/pkg/sentry/devices/memdev/BUILD +++ b/pkg/sentry/devices/memdev/BUILD @@ -18,9 +18,10 @@ go_library( "//pkg/rand", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", + "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/memmap", - "//pkg/sentry/mm", - "//pkg/sentry/pgalloc", "//pkg/sentry/vfs", "//pkg/syserror", "//pkg/usermem", diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go index 2e631a252..60cfea888 100644 --- a/pkg/sentry/devices/memdev/zero.go +++ b/pkg/sentry/devices/memdev/zero.go @@ -16,9 +16,10 @@ package memdev import ( "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/mm" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/usermem" ) @@ -79,11 +80,22 @@ func (fd *zeroFD) Seek(ctx context.Context, offset int64, whence int32) (int64, // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - m, err := mm.NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx)) + if opts.Private || !opts.MaxPerms.Write { + // This mapping will never permit writing to the "underlying file" (in + // Linux terms, it isn't VM_SHARED), so implement it as an anonymous + // mapping, but back it with fd; this is what Linux does, and is + // actually application-visible because the resulting VMA will show up + // in /proc/[pid]/maps with fd.vfsfd.VirtualDentry()'s path rather than + // "/dev/zero (deleted)". + opts.Offset = 0 + opts.MappingIdentity = &fd.vfsfd + opts.MappingIdentity.IncRef() + return nil + } + tmpfsFD, err := tmpfs.NewZeroFile(ctx, auth.CredentialsFromContext(ctx), kernel.KernelFromContext(ctx).ShmMount(), opts.Length) if err != nil { return err } - opts.MappingIdentity = m - opts.Mappable = m - return nil + defer tmpfsFD.DecRef(ctx) + return tmpfsFD.ConfigureMMap(ctx, opts) } diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index a40625e19..0b701a289 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -64,12 +64,13 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg request := args[1].Uint() data := args[2].Pointer() + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + switch request { case linux.TUNSETIFF: - t := kernel.TaskFromContext(ctx) - if t == nil { - panic("Ioctl should be called from a task context") - } if !t.HasCapability(linux.CAP_NET_ADMIN) { return 0, syserror.EPERM } @@ -79,9 +80,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg } var req linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, uio, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := req.CopyIn(t, data); err != nil { return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) @@ -97,9 +96,7 @@ func (fd *tunFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArg flags := fd.device.Flags() | linux.IFF_NOFILTER usermem.ByteOrder.PutUint16(req.Data[:], flags) - _, err := usermem.CopyObjectOut(ctx, uio, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := req.CopyOut(t, data) return 0, err default: diff --git a/pkg/sentry/fdimport/BUILD b/pkg/sentry/fdimport/BUILD index 5e41ceb4e..6b4f8b0ed 100644 --- a/pkg/sentry/fdimport/BUILD +++ b/pkg/sentry/fdimport/BUILD @@ -10,6 +10,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/context", + "//pkg/fd", "//pkg/sentry/fs", "//pkg/sentry/fs/host", "//pkg/sentry/fsimpl/host", diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index 1b7cb94c0..314661475 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/host" hostvfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" @@ -27,8 +28,9 @@ import ( // Import imports a slice of FDs into the given FDTable. If console is true, // sets up TTY for the first 3 FDs in the slice representing stdin, stdout, -// stderr. Upon success, Import takes ownership of all FDs. -func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { +// stderr. Used FDs are either closed or released. It's safe for the caller to +// close any remaining files upon return. +func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { if kernel.VFS2Enabled { ttyFile, err := importVFS2(ctx, fdTable, console, fds) return nil, ttyFile, err @@ -37,7 +39,7 @@ func Import(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []in return ttyFile, nil, err } -func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []int) (*host.TTYFileOperations, error) { +func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds []*fd.FD) (*host.TTYFileOperations, error) { var ttyFile *fs.File for appFD, hostFD := range fds { var appFile *fs.File @@ -46,11 +48,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] // Import the file as a host TTY file. if ttyFile == nil { var err error - appFile, err = host.ImportFile(ctx, hostFD, true /* isTTY */) + appFile, err = host.ImportFile(ctx, hostFD.FD(), true /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + _ = hostFD.Close() // FD is dup'd i ImportFile. // Remember this in the TTY file, as we will // use it for the other stdio FDs. @@ -65,11 +68,12 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] } else { // Import the file as a regular host file. var err error - appFile, err = host.ImportFile(ctx, hostFD, false /* isTTY */) + appFile, err = host.ImportFile(ctx, hostFD.FD(), false /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + _ = hostFD.Close() // FD is dup'd i ImportFile. } // Add the file to the FD map. @@ -84,7 +88,7 @@ func importFS(ctx context.Context, fdTable *kernel.FDTable, console bool, fds [] return ttyFile.FileOperations.(*host.TTYFileOperations), nil } -func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []int) (*hostvfs2.TTYFileDescription, error) { +func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdioFDs []*fd.FD) (*hostvfs2.TTYFileDescription, error) { k := kernel.KernelFromContext(ctx) if k == nil { return nil, fmt.Errorf("cannot find kernel from context") @@ -98,11 +102,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi // Import the file as a host TTY file. if ttyFile == nil { var err error - appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, true /* isTTY */) + appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), true /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + hostFD.Release() // FD is transfered to host FD. // Remember this in the TTY file, as we will use it for the other stdio // FDs. @@ -115,11 +120,12 @@ func importVFS2(ctx context.Context, fdTable *kernel.FDTable, console bool, stdi } } else { var err error - appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD, false /* isTTY */) + appFile, err = hostvfs2.ImportFD(ctx, k.HostMount(), hostFD.FD(), false /* isTTY */) if err != nil { return nil, err } defer appFile.DecRef(ctx) + hostFD.Release() // FD is transfered to host FD. } if err := fdTable.NewFDAtVFS2(ctx, int32(appFD), appFile, kernel.FDFlags{}); err != nil { diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index 735452b07..ff2fe6712 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -107,8 +107,7 @@ func copyUp(ctx context.Context, d *Dirent) error { // leave the upper filesystem filled with any number of parent directories // but the upper filesystem will never be in an inconsistent state. // -// Preconditions: -// - d.Inode.overlay is non-nil. +// Preconditions: d.Inode.overlay is non-nil. func copyUpLockedForRename(ctx context.Context, d *Dirent) error { for { // Did we race with another copy up or does there @@ -183,12 +182,12 @@ func doCopyUp(ctx context.Context, d *Dirent) error { // Returns a generic error on failure. // // Preconditions: -// - parent.Inode.overlay.upper must be non-nil. -// - next.Inode.overlay.copyMu must be locked writable. -// - next.Inode.overlay.lower must be non-nil. -// - next.Inode.overlay.lower.StableAttr.Type must be RegularFile, Directory, +// * parent.Inode.overlay.upper must be non-nil. +// * next.Inode.overlay.copyMu must be locked writable. +// * next.Inode.overlay.lower must be non-nil. +// * next.Inode.overlay.lower.StableAttr.Type must be RegularFile, Directory, // or Symlink. -// - upper filesystem must support setting file ownership and timestamps. +// * upper filesystem must support setting file ownership and timestamps. func copyUpLocked(ctx context.Context, parent *Dirent, next *Dirent) error { // Extract the attributes of the file we wish to copy. attrs, err := next.Inode.overlay.lower.UnstableAttr(ctx) diff --git a/pkg/sentry/fs/dev/net_tun.go b/pkg/sentry/fs/dev/net_tun.go index ec474e554..5f8c9b5a2 100644 --- a/pkg/sentry/fs/dev/net_tun.go +++ b/pkg/sentry/fs/dev/net_tun.go @@ -89,12 +89,13 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u request := args[1].Uint() data := args[2].Pointer() + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("Ioctl should be called from a task context") + } + switch request { case linux.TUNSETIFF: - t := kernel.TaskFromContext(ctx) - if t == nil { - panic("Ioctl should be called from a task context") - } if !t.HasCapability(linux.CAP_NET_ADMIN) { return 0, syserror.EPERM } @@ -104,9 +105,7 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u } var req linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := req.CopyIn(t, data); err != nil { return 0, err } flags := usermem.ByteOrder.Uint16(req.Data[:]) @@ -122,9 +121,7 @@ func (fops *netTunFileOperations) Ioctl(ctx context.Context, file *fs.File, io u flags := fops.device.Flags() | linux.IFF_NOFILTER usermem.ByteOrder.PutUint16(req.Data[:], flags) - _, err := usermem.CopyObjectOut(ctx, io, data, &req, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := req.CopyOut(t, data) return 0, err default: diff --git a/pkg/sentry/fs/dirent.go b/pkg/sentry/fs/dirent.go index a2f751068..00c526b03 100644 --- a/pkg/sentry/fs/dirent.go +++ b/pkg/sentry/fs/dirent.go @@ -413,9 +413,9 @@ func (d *Dirent) descendantOf(p *Dirent) bool { // Inode.Lookup, otherwise walk will keep d.mu locked. // // Preconditions: -// - renameMu must be held for reading. -// - d.mu must be held. -// - name must must not contain "/"s. +// * renameMu must be held for reading. +// * d.mu must be held. +// * name must must not contain "/"s. func (d *Dirent) walk(ctx context.Context, root *Dirent, name string, walkMayUnlock bool) (*Dirent, error) { if !IsDir(d.Inode.StableAttr) { return nil, syscall.ENOTDIR @@ -577,9 +577,9 @@ func (d *Dirent) Walk(ctx context.Context, root *Dirent, name string) (*Dirent, // exists returns true if name exists in relation to d. // // Preconditions: -// - renameMu must be held for reading. -// - d.mu must be held. -// - name must must not contain "/"s. +// * renameMu must be held for reading. +// * d.mu must be held. +// * name must must not contain "/"s. func (d *Dirent) exists(ctx context.Context, root *Dirent, name string) bool { child, err := d.walk(ctx, root, name, false /* may unlock */) if err != nil { diff --git a/pkg/sentry/fs/file_operations.go b/pkg/sentry/fs/file_operations.go index 305c0f840..6ec721022 100644 --- a/pkg/sentry/fs/file_operations.go +++ b/pkg/sentry/fs/file_operations.go @@ -159,8 +159,9 @@ type FileOperations interface { // io provides access to the virtual memory space to which pointers in args // refer. // - // Preconditions: The AddressSpace (if any) that io refers to is activated. - // Must only be called from a task goroutine. + // Preconditions: + // * The AddressSpace (if any) that io refers to is activated. + // * Must only be called from a task goroutine. Ioctl(ctx context.Context, file *File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) } diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go index bbafebf03..9197aeb88 100644 --- a/pkg/sentry/fs/fsutil/file_range_set.go +++ b/pkg/sentry/fs/fsutil/file_range_set.go @@ -70,7 +70,9 @@ func (seg FileRangeIterator) FileRange() memmap.FileRange { // FileRangeOf returns the FileRange mapped by mr. // -// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0. +// Preconditions: +// * seg.Range().IsSupersetOf(mr). +// * mr.Length() != 0. func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange { frstart := seg.Value() + (mr.Start - seg.Start()) return memmap.FileRange{frstart, frstart + mr.Length()} @@ -88,8 +90,10 @@ func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRan // outside of optional. It returns a non-nil error if any error occurs, even // if the error only affects offsets in optional, but not in required. // -// Preconditions: required.Length() > 0. optional.IsSupersetOf(required). -// required and optional must be page-aligned. +// Preconditions: +// * required.Length() > 0. +// * optional.IsSupersetOf(required). +// * required and optional must be page-aligned. func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.MappableRange, mf *pgalloc.MemoryFile, kind usage.MemoryKind, readAt func(ctx context.Context, dsts safemem.BlockSeq, offset uint64) (uint64, error)) error { gap := frs.LowerBoundGap(required.Start) for gap.Ok() && gap.Start() < required.End { diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index ef0113b52..1390a9a7f 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -80,7 +80,9 @@ func NewHostFileMapper() *HostFileMapper { // IncRefOn increments the reference count on all offsets in mr. // -// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned. +// Preconditions: +// * mr.Length() != 0. +// * mr.Start and mr.End must be page-aligned. func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() @@ -97,7 +99,9 @@ func (f *HostFileMapper) IncRefOn(mr memmap.MappableRange) { // DecRefOn decrements the reference count on all offsets in mr. // -// Preconditions: mr.Length() != 0. mr.Start and mr.End must be page-aligned. +// Preconditions: +// * mr.Length() != 0. +// * mr.Start and mr.End must be page-aligned. func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) { f.refsMu.Lock() defer f.refsMu.Unlock() @@ -204,7 +208,9 @@ func (f *HostFileMapper) UnmapAll() { } } -// Preconditions: f.mapsMu must be locked. f.mappings[chunkStart] == m. +// Preconditions: +// * f.mapsMu must be locked. +// * f.mappings[chunkStart] == m. func (f *HostFileMapper) unmapAndRemoveLocked(chunkStart uint64, m mapping) { if _, _, errno := syscall.Syscall(syscall.SYS_MUNMAP, m.addr, chunkSize, 0); errno != 0 { // This leaks address space and is unexpected, but is otherwise diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go index fe8b0b6ac..9eb6f522e 100644 --- a/pkg/sentry/fs/fsutil/inode_cached.go +++ b/pkg/sentry/fs/fsutil/inode_cached.go @@ -684,7 +684,9 @@ func (rw *inodeReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // maybeGrowFile grows the file's size if data has been written past the old // size. // -// Preconditions: rw.c.attrMu and rw.c.dataMu bust be locked. +// Preconditions: +// * rw.c.attrMu must be locked. +// * rw.c.dataMu must be locked. func (rw *inodeReadWriter) maybeGrowFile() { // If the write ends beyond the file's previous size, it causes the // file to grow. diff --git a/pkg/sentry/fs/g3doc/fuse.md b/pkg/sentry/fs/g3doc/fuse.md index 2ca84dd74..05e043583 100644 --- a/pkg/sentry/fs/g3doc/fuse.md +++ b/pkg/sentry/fs/g3doc/fuse.md @@ -79,7 +79,7 @@ ops can be implemented in parallel. - Implement `/dev/fuse` - a character device used to establish an FD for communication between the sentry and the server daemon. -- Implement basic FUSE ops like `FUSE_INIT`, `FUSE_DESTROY`. +- Implement basic FUSE ops like `FUSE_INIT`. #### Read-only mount with basic file operations @@ -95,6 +95,103 @@ ops can be implemented in parallel. - Implement the remaining FUSE ops and decide if we can omit rarely used operations like ioctl. +### Design Details + +#### Lifecycle for a FUSE Request + +- User invokes a syscall +- Sentry prepares corresponding request + - If FUSE device is available + - Write the request in binary + - If FUSE device is full + - Kernel task blocked until available +- Sentry notifies the readers of fuse device that it's ready for read +- FUSE daemon reads the request and processes it +- Sentry waits until a reply is written to the FUSE device + - but returns directly for async requests +- FUSE daemon writes to the fuse device +- Sentry processes the reply + - For sync requests, unblock blocked kernel task + - For async requests, execute pre-specified callback if any +- Sentry returns the syscall to the user + +#### Channels and Queues for Requests in Different Stages + +`connection.initializedChan` + +- a channel that the requests issued before connection initialization blocks + on. + +`fd.queue` + +- a queue of requests that haven’t been read by the FUSE daemon yet. + +`fd.completions` + +- a map of the requests that have been prepared but not yet received a + response, including the ones on the `fd.queue`. + +`fd.waitQueue` + +- a queue of waiters that is waiting for the fuse device fd to be available, + such as the FUSE daemon. + +`fd.fullQueueCh` + +- a channel that the kernel task will be blocked on when the fd is not + available. + +#### Basic I/O Implementation + +Currently we have implemented basic functionalities of read and write for our +FUSE. We describe the design and ways to improve it here: + +##### Basic FUSE Read + +The vfs2 expects implementations of `vfs.FileDescriptionImpl.Read()` and +`vfs.FileDescriptionImpl.PRead()`. When a syscall is made, it will eventually +reach our implementation of those interface functions located at +`pkg/sentry/fsimpl/fuse/regular_file.go` for regular files. + +After validation checks of the input, sentry sends `FUSE_READ` requests to the +FUSE daemon. The FUSE daemon returns data after the `fuse_out_header` as the +responses. For the first version, we create a copy in kernel memory of those +data. They are represented as a byte slice in the marshalled struct. This +happens as a common process for all the FUSE responses at this moment at +`pkg/sentry/fsimpl/fuse/dev.go:writeLocked()`. We then directly copy from this +intermediate buffer to the input buffer provided by the read syscall. + +There is an extra requirement for FUSE: When mounting the FUSE fs, the mounter +or the FUSE daemon can specify a `max_read` or a `max_pages` parameter. They are +the upperbound of the bytes to read in each `FUSE_READ` request. We implemented +the code to handle the fragmented reads. + +To improve the performance: ideally we should have buffer cache to copy those +data from the responses of FUSE daemon into, as is also the design of several +other existing file system implementations for sentry, instead of a single-use +temporary buffer. Directly mapping the memory of one process to another could +also boost the performance, but to keep them isolated, we did not choose to do +so. + +##### Basic FUSE Write + +The vfs2 invokes implementations of `vfs.FileDescriptionImpl.Write()` and +`vfs.FileDescriptionImpl.PWrite()` on the regular file descriptor of FUSE when a +user makes write(2) and pwrite(2) syscall. + +For valid writes, sentry sends the bytes to write after a `FUSE_WRITE` header +(can be regarded as a request with 2 payloads) to the FUSE daemon. For the first +version, we allocate a buffer inside kernel memory to store the bytes from the +user, and copy directly from that buffer to the memory of FUSE daemon. This +happens at `pkg/sentry/fsimpl/fuse/dev.go:readLocked()` + +The parameters `max_write` and `max_pages` restrict the number of bytes in one +`FUSE_WRITE`. There are code handling fragmented writes in current +implementation. + +To have better performance: the extra copy created to store the bytes to write +can be replaced by the buffer cache as well. + # Appendix ## FUSE Protocol diff --git a/pkg/sentry/fs/host/BUILD b/pkg/sentry/fs/host/BUILD index d41d23a43..1368014c4 100644 --- a/pkg/sentry/fs/host/BUILD +++ b/pkg/sentry/fs/host/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/fdnotifier", "//pkg/iovec", "//pkg/log", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/secio", diff --git a/pkg/sentry/fs/host/socket_unsafe.go b/pkg/sentry/fs/host/socket_unsafe.go index 5d4f312cf..c8231e0aa 100644 --- a/pkg/sentry/fs/host/socket_unsafe.go +++ b/pkg/sentry/fs/host/socket_unsafe.go @@ -65,10 +65,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) ( controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC if n > length { - return length, n, msg.Controllen, controlTrunc, err + return length, n, msg.Controllen, controlTrunc, nil } - return n, n, msg.Controllen, controlTrunc, err + return n, n, msg.Controllen, controlTrunc, nil } // fdWriteVec sends from bufs to fd. diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go index b5229098c..1183727ab 100644 --- a/pkg/sentry/fs/host/tty.go +++ b/pkg/sentry/fs/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -53,7 +54,7 @@ type TTYFileOperations struct { func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File { return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{ fileOperations: fileOperations{iops: iops}, - termios: linux.DefaultSlaveTermios, + termios: linux.DefaultReplicaTermios, }) } @@ -123,6 +124,11 @@ func (t *TTYFileOperations) Release(ctx context.Context) { // Ioctl implements fs.FileOperations.Ioctl. func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + return 0, syserror.ENOTTY + } + // Ignore arg[0]. This is the real FD: fd := t.fileOperations.iops.fileState.FD() ioctl := args[1].Uint64() @@ -132,9 +138,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = termios.CopyOut(task, args[2].Pointer()) return 0, err case linux.TCSETS, linux.TCSETSW, linux.TCSETSF: @@ -146,9 +150,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO } var termios linux.Termios - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := termios.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetTermios(fd, ioctl, &termios) @@ -173,10 +175,8 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO // Map the ProcessGroup into a ProcessGroupID in the task's PID // namespace. - pgID := pidns.IDOfProcessGroup(t.fgProcessGroup) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }) + pgID := primitive.Int32(pidns.IDOfProcessGroup(t.fgProcessGroup)) + _, err := pgID.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSPGRP: @@ -184,11 +184,6 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO // Equivalent to tcsetpgrp(fd, *argp). // Set the foreground process group ID of this terminal. - task := kernel.TaskFromContext(ctx) - if task == nil { - return 0, syserror.ENOTTY - } - t.mu.Lock() defer t.mu.Unlock() @@ -208,12 +203,11 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO return 0, syserror.ENOTTY } - var pgID kernel.ProcessGroupID - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgIDP primitive.Int32 + if _, err := pgIDP.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } + pgID := kernel.ProcessGroupID(pgIDP) // pgID must be non-negative. if pgID < 0 { @@ -242,9 +236,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = winsize.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSWINSZ: @@ -255,9 +247,7 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO // background ones) can set the winsize. var winsize linux.Winsize - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := winsize.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetWinsize(fd, &winsize) @@ -358,7 +348,7 @@ func (t *TTYFileOperations) checkChange(ctx context.Context, sig linux.Signal) e // // Linux ignores the result of kill_pgrp(). _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) - return kernel.ERESTARTSYS + return syserror.ERESTARTSYS } // LINT.ThenChange(../../fsimpl/host/tty.go) diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index b79cd9877..004910453 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -270,7 +270,7 @@ func (i *Inode) GetXattr(ctx context.Context, name string, size uint64) (string, // SetXattr calls i.InodeOperations.SetXattr with i as the Inode. func (i *Inode) SetXattr(ctx context.Context, d *Dirent, name, value string, flags uint32) error { if i.overlay != nil { - return overlaySetxattr(ctx, i.overlay, d, name, value, flags) + return overlaySetXattr(ctx, i.overlay, d, name, value, flags) } return i.InodeOperations.SetXattr(ctx, i, name, value, flags) } diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index dc2e353d9..b16ab08ba 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -16,7 +16,6 @@ package fs import ( "fmt" - "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -539,7 +538,7 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin // Don't forward the value of the extended attribute if it would // unexpectedly change the behavior of a wrapping overlay layer. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return "", syserror.ENODATA } @@ -553,9 +552,9 @@ func overlayGetXattr(ctx context.Context, o *overlayEntry, name string, size uin return s, err } -func overlaySetxattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error { +func overlaySetXattr(ctx context.Context, o *overlayEntry, d *Dirent, name, value string, flags uint32) error { // Don't allow changes to overlay xattrs through a setxattr syscall. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return syserror.EPERM } @@ -578,7 +577,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st for name := range names { // Same as overlayGetXattr, we shouldn't forward along // overlay attributes. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { delete(names, name) } } @@ -587,7 +586,7 @@ func overlayListXattr(ctx context.Context, o *overlayEntry, size uint64) (map[st func overlayRemoveXattr(ctx context.Context, o *overlayEntry, d *Dirent, name string) error { // Don't allow changes to overlay xattrs through a removexattr syscall. - if strings.HasPrefix(XattrOverlayPrefix, name) { + if isXattrOverlay(name) { return syserror.EPERM } diff --git a/pkg/sentry/fs/overlay.go b/pkg/sentry/fs/overlay.go index 35013a21b..01a1235b8 100644 --- a/pkg/sentry/fs/overlay.go +++ b/pkg/sentry/fs/overlay.go @@ -86,13 +86,12 @@ func isXattrOverlay(name string) bool { // NewOverlayRoot produces the root of an overlay. // // Preconditions: -// -// - upper and lower must be non-nil. -// - upper must not be an overlay. -// - lower should not expose character devices, pipes, or sockets, because +// * upper and lower must be non-nil. +// * upper must not be an overlay. +// * lower should not expose character devices, pipes, or sockets, because // copying up these types of files is not supported. -// - lower must not require that file objects be revalidated. -// - lower must not have dynamic file/directory content. +// * lower must not require that file objects be revalidated. +// * lower must not have dynamic file/directory content. func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags MountSourceFlags) (*Inode, error) { if !IsDir(upper.StableAttr) { return nil, fmt.Errorf("upper Inode is a %v, not a directory", upper.StableAttr.Type) @@ -117,12 +116,11 @@ func NewOverlayRoot(ctx context.Context, upper *Inode, lower *Inode, flags Mount // NewOverlayRootFile produces the root of an overlay that points to a file. // // Preconditions: -// -// - lower must be non-nil. -// - lower should not expose character devices, pipes, or sockets, because +// * lower must be non-nil. +// * lower should not expose character devices, pipes, or sockets, because // copying up these types of files is not supported. Neither it can be a dir. -// - lower must not require that file objects be revalidated. -// - lower must not have dynamic file/directory content. +// * lower must not require that file objects be revalidated. +// * lower must not have dynamic file/directory content. func NewOverlayRootFile(ctx context.Context, upperMS *MountSource, lower *Inode, flags MountSourceFlags) (*Inode, error) { if !IsRegular(lower.StableAttr) { return nil, fmt.Errorf("lower Inode is not a regular file") diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index f2f49a7f6..e555672ad 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -55,7 +55,7 @@ type tcpMemInode struct { // size stores the tcp buffer size during save, and sets the buffer // size in netstack in restore. We must save/restore this here, since - // netstack itself is stateless. + // a netstack instance is created on restore. size inet.TCPBufferSize // mu protects against concurrent reads/writes to files based on this @@ -259,6 +259,9 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque if src.NumBytes() == 0 { return 0, nil } + + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. src = src.TakeFirst(usermem.PageSize - 1) var v int32 @@ -390,21 +393,14 @@ func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.S // // +stateify savable type ipForwarding struct { - stack inet.Stack `state:".(ipForwardingState)"` fsutil.SimpleFileInode -} -// ipForwardingState is used to stores a state of netstack -// for packet forwarding because netstack itself is stateless. -// -// +stateify savable -type ipForwardingState struct { stack inet.Stack `state:"wait"` - // enabled stores packet forwarding settings during save, and sets it back - // in netstack in restore. We must save/restore this here, since - // netstack itself is stateless. - enabled bool + // enabled stores the IPv4 forwarding state on save. + // We must save/restore this here, since a netstack instance + // is created on restore. + enabled *bool } func newIPForwardingInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { @@ -441,6 +437,8 @@ type ipForwardingFile struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` waiter.AlwaysReady `state:"nosave"` + ipf *ipForwarding + stack inet.Stack `state:"wait"` } @@ -450,6 +448,7 @@ func (ipf *ipForwarding) GetFile(ctx context.Context, dirent *fs.Dirent, flags f flags.Pwrite = true return fs.NewFile(ctx, dirent, flags, &ipForwardingFile{ stack: ipf.stack, + ipf: ipf, }), nil } @@ -459,14 +458,18 @@ func (f *ipForwardingFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return 0, io.EOF } + if f.ipf.enabled == nil { + enabled := f.stack.Forwarding(ipv4.ProtocolNumber) + f.ipf.enabled = &enabled + } + val := "0\n" - if f.stack.Forwarding(ipv4.ProtocolNumber) { + if *f.ipf.enabled { // Technically, this is not quite compatible with Linux. Linux // stores these as an integer, so if you write "2" into // ip_forward, you should get 2 back. val = "1\n" } - n, err := dst.CopyOut(ctx, []byte(val)) return int64(n), err } @@ -479,7 +482,8 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO return 0, nil } - // Only consider size of one memory page for input. + // Only consider size of one memory page for input for performance reasons. + // We are only reading if it's zero or not anyway. src = src.TakeFirst(usermem.PageSize - 1) var v int32 @@ -487,9 +491,11 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO if err != nil { return n, err } - - enabled := v != 0 - return n, f.stack.SetForwarding(ipv4.ProtocolNumber, enabled) + if f.ipf.enabled == nil { + f.ipf.enabled = new(bool) + } + *f.ipf.enabled = v != 0 + return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) } func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { diff --git a/pkg/sentry/fs/proc/sys_net_state.go b/pkg/sentry/fs/proc/sys_net_state.go index 3fadb870e..4cb4741af 100644 --- a/pkg/sentry/fs/proc/sys_net_state.go +++ b/pkg/sentry/fs/proc/sys_net_state.go @@ -45,18 +45,11 @@ func (s *tcpSack) afterLoad() { } } -// saveStack is invoked by stateify. -func (ipf *ipForwarding) saveStack() ipForwardingState { - return ipForwardingState{ - ipf.stack, - ipf.stack.Forwarding(ipv4.ProtocolNumber), - } -} - -// loadStack is invoked by stateify. -func (ipf *ipForwarding) loadStack(s ipForwardingState) { - ipf.stack = s.stack - if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, s.enabled); err != nil { - panic(fmt.Sprintf("failed to set previous IPv4 forwarding configuration [%v]: %v", s.enabled, err)) +// afterLoad is invoked by stateify. +func (ipf *ipForwarding) afterLoad() { + if ipf.enabled != nil { + if err := ipf.stack.SetForwarding(ipv4.ProtocolNumber, *ipf.enabled); err != nil { + panic(fmt.Sprintf("failed to set IPv4 forwarding [%v]: %v", *ipf.enabled, err)) + } } } diff --git a/pkg/sentry/fs/proc/sys_net_test.go b/pkg/sentry/fs/proc/sys_net_test.go index 72c9857d0..6ef5738e7 100644 --- a/pkg/sentry/fs/proc/sys_net_test.go +++ b/pkg/sentry/fs/proc/sys_net_test.go @@ -176,18 +176,21 @@ func TestIPForwarding(t *testing.T) { for _, c := range cases { t.Run(c.comment, func(t *testing.T) { s.IPForwarding = c.initial - - file := &ipForwardingFile{stack: s} + ipf := &ipForwarding{stack: s} + file := &ipForwardingFile{ + stack: s, + ipf: ipf, + } // Write the values. src := usermem.BytesIOSequence([]byte(c.str)) if n, err := file.Write(ctx, nil, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("file.Write(ctx, nil, %v, 0) = (%d, %v); wanted (%d, nil)", c.str, n, err, len(c.str)) + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) } // Read the values from the stack and check them. - if s.IPForwarding != c.final { - t.Errorf("s.IPForwarding = %v; wanted %v", s.IPForwarding, c.final) + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) } }) diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 9cf7f2a62..103bfc600 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -604,7 +604,7 @@ func (s *statusData) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ( var vss, rss, data uint64 s.t.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + fds = fdTable.CurrentMaxFDs() } if mm := t.MemoryManager(); mm != nil { vss = mm.VirtualMemorySize() diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index b095312fe..998b697ca 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -16,6 +16,8 @@ package tmpfs import ( + "math" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -32,9 +34,15 @@ import ( var fsInfo = fs.Info{ Type: linux.TMPFS_MAGIC, + // tmpfs currently does not support configurable size limits. In Linux, + // such a tmpfs mount will return f_blocks == f_bfree == f_bavail == 0 from + // statfs(2). However, many applications treat this as having a size limit + // of 0. To work around this, claim to have a very large but non-zero size, + // chosen to ensure that BlockSize * Blocks does not overflow int64 (which + // applications may also handle incorrectly). // TODO(b/29637826): allow configuring a tmpfs size and enforce it. - TotalBlocks: 0, - FreeBlocks: 0, + TotalBlocks: math.MaxInt64 / usermem.PageSize, + FreeBlocks: math.MaxInt64 / usermem.PageSize, } // rename implements fs.InodeOperations.Rename for tmpfs nodes. diff --git a/pkg/sentry/fs/tty/BUILD b/pkg/sentry/fs/tty/BUILD index 5cb0e0417..e6d0eb359 100644 --- a/pkg/sentry/fs/tty/BUILD +++ b/pkg/sentry/fs/tty/BUILD @@ -10,13 +10,14 @@ go_library( "line_discipline.go", "master.go", "queue.go", - "slave.go", + "replica.go", "terminal.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/fs/tty/dir.go b/pkg/sentry/fs/tty/dir.go index 463f6189e..c2da80bc2 100644 --- a/pkg/sentry/fs/tty/dir.go +++ b/pkg/sentry/fs/tty/dir.go @@ -37,14 +37,14 @@ import ( // This indirectly manages all terminals within the mount. // // New Terminals are created by masterInodeOperations.GetFile, which registers -// the slave Inode in the this directory for discovery via Lookup/Readdir. The -// slave inode is unregistered when the master file is Released, as the slave +// the replica Inode in the this directory for discovery via Lookup/Readdir. The +// replica inode is unregistered when the master file is Released, as the replica // is no longer discoverable at that point. // // References on the underlying Terminal are held by masterFileOperations and -// slaveInodeOperations. +// replicaInodeOperations. // -// masterInodeOperations and slaveInodeOperations hold a pointer to +// masterInodeOperations and replicaInodeOperations hold a pointer to // dirInodeOperations, which is reference counted by the refcount their // corresponding Dirents hold on their parent (this directory). // @@ -76,16 +76,16 @@ type dirInodeOperations struct { // master is the master PTY inode. master *fs.Inode - // slaves contains the slave inodes reachable from the directory. + // replicas contains the replica inodes reachable from the directory. // - // A new slave is added by allocateTerminal and is removed by + // A new replica is added by allocateTerminal and is removed by // masterFileOperations.Release. // - // A reference is held on every slave in the map. - slaves map[uint32]*fs.Inode + // A reference is held on every replica in the map. + replicas map[uint32]*fs.Inode // dentryMap is a SortedDentryMap used to implement Readdir containing - // the master and all entries in slaves. + // the master and all entries in replicas. dentryMap *fs.SortedDentryMap // next is the next pty index to use. @@ -101,7 +101,7 @@ func newDir(ctx context.Context, m *fs.MountSource) *fs.Inode { d := &dirInodeOperations{ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.RootOwner, fs.FilePermsFromMode(0555), linux.DEVPTS_SUPER_MAGIC), msrc: m, - slaves: make(map[uint32]*fs.Inode), + replicas: make(map[uint32]*fs.Inode), dentryMap: fs.NewSortedDentryMap(nil), } // Linux devpts uses a default mode of 0000 for ptmx which can be @@ -133,7 +133,7 @@ func (d *dirInodeOperations) Release(ctx context.Context) { defer d.mu.Unlock() d.master.DecRef(ctx) - if len(d.slaves) != 0 { + if len(d.replicas) != 0 { panic(fmt.Sprintf("devpts directory still contains active terminals: %+v", d)) } } @@ -149,14 +149,14 @@ func (d *dirInodeOperations) Lookup(ctx context.Context, dir *fs.Inode, name str return fs.NewDirent(ctx, d.master, name), nil } - // Slave number? + // Replica number? n, err := strconv.ParseUint(name, 10, 32) if err != nil { // Not found. return nil, syserror.ENOENT } - s, ok := d.slaves[uint32(n)] + s, ok := d.replicas[uint32(n)] if !ok { return nil, syserror.ENOENT } @@ -236,7 +236,7 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e return nil, syserror.ENOMEM } - if _, ok := d.slaves[n]; ok { + if _, ok := d.replicas[n]; ok { panic(fmt.Sprintf("pty index collision; index %d already exists", n)) } @@ -244,19 +244,19 @@ func (d *dirInodeOperations) allocateTerminal(ctx context.Context) (*Terminal, e d.next++ // The reference returned by newTerminal is returned to the caller. - // Take another for the slave inode. + // Take another for the replica inode. t.IncRef() // Create a pts node. The owner is based on the context that opens // ptmx. creds := auth.CredentialsFromContext(ctx) uid, gid := creds.EffectiveKUID, creds.EffectiveKGID - slave := newSlaveInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666)) + replica := newReplicaInode(ctx, d, t, fs.FileOwner{uid, gid}, fs.FilePermsFromMode(0666)) - d.slaves[n] = slave + d.replicas[n] = replica d.dentryMap.Add(strconv.FormatUint(uint64(n), 10), fs.DentAttr{ - Type: slave.StableAttr.Type, - InodeID: slave.StableAttr.InodeID, + Type: replica.StableAttr.Type, + InodeID: replica.StableAttr.InodeID, }) return t, nil @@ -267,18 +267,18 @@ func (d *dirInodeOperations) masterClose(ctx context.Context, t *Terminal) { d.mu.Lock() defer d.mu.Unlock() - // The slave end disappears from the directory when the master end is - // closed, even if the slave end is open elsewhere. + // The replica end disappears from the directory when the master end is + // closed, even if the replica end is open elsewhere. // // N.B. since we're using a backdoor method to remove a directory entry // we won't properly fire inotify events like Linux would. - s, ok := d.slaves[t.n] + s, ok := d.replicas[t.n] if !ok { panic(fmt.Sprintf("Terminal %+v doesn't exist in %+v?", t, d)) } s.DecRef(ctx) - delete(d.slaves, t.n) + delete(d.replicas, t.n) d.dentryMap.Remove(strconv.FormatUint(uint64(t.n), 10)) } diff --git a/pkg/sentry/fs/tty/fs.go b/pkg/sentry/fs/tty/fs.go index 2d4d44bf3..13f4901db 100644 --- a/pkg/sentry/fs/tty/fs.go +++ b/pkg/sentry/fs/tty/fs.go @@ -79,8 +79,8 @@ type superOperations struct{} // // It always returns true, forcing a Lookup for all entries. // -// Slave entries are dropped from dir when their master is closed, so an -// existing slave Dirent in the tree is not sufficient to guarantee that it +// Replica entries are dropped from dir when their master is closed, so an +// existing replica Dirent in the tree is not sufficient to guarantee that it // still exists on the filesystem. func (superOperations) Revalidate(context.Context, string, *fs.Inode, *fs.Inode) bool { return true diff --git a/pkg/sentry/fs/tty/line_discipline.go b/pkg/sentry/fs/tty/line_discipline.go index 2e9dd2d55..b34f4a0eb 100644 --- a/pkg/sentry/fs/tty/line_discipline.go +++ b/pkg/sentry/fs/tty/line_discipline.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -43,7 +44,7 @@ const ( ) // lineDiscipline dictates how input and output are handled between the -// pseudoterminal (pty) master and slave. It can be configured to alter I/O, +// pseudoterminal (pty) master and replica. It can be configured to alter I/O, // modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man // pages are good resources for how to affect the line discipline: // @@ -54,8 +55,8 @@ const ( // // lineDiscipline has a simple structure but supports a multitude of options // (see the above man pages). It consists of two queues of bytes: one from the -// terminal master to slave (the input queue) and one from slave to master (the -// output queue). When bytes are written to one end of the pty, the line +// terminal master to replica (the input queue) and one from replica to master +// (the output queue). When bytes are written to one end of the pty, the line // discipline reads the bytes, modifies them or takes special action if // required, and enqueues them to be read by the other end of the pty: // @@ -64,7 +65,7 @@ const ( // | (inputQueueWrite) +-------------+ (inputQueueRead) | // | | // | v -// masterFD slaveFD +// masterFD replicaFD // ^ | // | | // | output to terminal +--------------+ output from process | @@ -103,8 +104,8 @@ type lineDiscipline struct { // masterWaiter is used to wait on the master end of the TTY. masterWaiter waiter.Queue `state:"zerovalue"` - // slaveWaiter is used to wait on the slave end of the TTY. - slaveWaiter waiter.Queue `state:"zerovalue"` + // replicaWaiter is used to wait on the replica end of the TTY. + replicaWaiter waiter.Queue `state:"zerovalue"` } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { @@ -115,27 +116,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { } // getTermios gets the linux.Termios for the tty. -func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.RLock() defer l.termiosMu.RUnlock() // We must copy a Termios struct, not KernelTermios. t := l.termios.ToTermios() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyOut(task, args[2].Pointer()) return 0, err } // setTermios sets a linux.Termios for the tty. -func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.Lock() defer l.termiosMu.Unlock() oldCanonEnabled := l.termios.LEnabled(linux.ICANON) // We must copy a Termios struct, not KernelTermios. var t linux.Termios - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyIn(task, args[2].Pointer()) l.termios.FromTermios(t) // If canonical mode is turned off, move bytes from inQueue's wait @@ -146,27 +143,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc l.inQueue.pushWaitBufLocked(l) l.inQueue.readable = true l.inQueue.mu.Unlock() - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return 0, err } -func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyOut(t, args[2].Pointer()) return err } -func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyIn(t, args[2].Pointer()) return err } @@ -176,14 +169,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask { return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios) } -func (l *lineDiscipline) slaveReadiness() waiter.EventMask { +func (l *lineDiscipline) replicaReadiness() waiter.EventMask { l.termiosMu.RLock() defer l.termiosMu.RUnlock() return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios) } -func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.inQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error { + return l.inQueue.readableSize(t, args) } func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -196,7 +189,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque if n > 0 { l.masterWaiter.Notify(waiter.EventOut) if pushed { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return n, nil } @@ -211,14 +204,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) return n, nil } return 0, syserror.ErrWouldBlock } -func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.outQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, args arch.SyscallArguments) error { + return l.outQueue.readableSize(t, args) } func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -229,7 +222,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventOut) + l.replicaWaiter.Notify(waiter.EventOut) if pushed { l.masterWaiter.Notify(waiter.EventIn) } diff --git a/pkg/sentry/fs/tty/master.go b/pkg/sentry/fs/tty/master.go index e00746017..b91184b1b 100644 --- a/pkg/sentry/fs/tty/master.go +++ b/pkg/sentry/fs/tty/master.go @@ -17,9 +17,11 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -152,46 +154,51 @@ func (mf *masterFileOperations) Write(ctx context.Context, _ *fs.File, src userm // Ioctl implements fs.FileOperations.Ioctl. func (mf *masterFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the output queue read buffer. - return 0, mf.t.ld.outputQueueReadSize(ctx, io, args) + return 0, mf.t.ld.outputQueueReadSize(t, args) case linux.TCGETS: // N.B. TCGETS on the master actually returns the configuration - // of the slave end. - return mf.t.ld.getTermios(ctx, io, args) + // of the replica end. + return mf.t.ld.getTermios(t, args) case linux.TCSETS: // N.B. TCSETS on the master actually affects the configuration - // of the slave end. - return mf.t.ld.setTermios(ctx, io, args) + // of the replica end. + return mf.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return mf.t.ld.setTermios(ctx, io, args) + return mf.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mf.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(mf.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCSPTLCK: // TODO(b/29356795): Implement pty locking. For now just pretend we do. return 0, nil case linux.TIOCGWINSZ: - return 0, mf.t.ld.windowSize(ctx, io, args) + return 0, mf.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, mf.t.ld.setWindowSize(ctx, io, args) + return 0, mf.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, mf.t.setControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mf.t.setControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, mf.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mf.t.releaseControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return mf.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + return mf.t.foregroundProcessGroup(ctx, args, true /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return mf.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) + return mf.t.setForegroundProcessGroup(ctx, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fs/tty/queue.go b/pkg/sentry/fs/tty/queue.go index ceabb9b1e..79975d812 100644 --- a/pkg/sentry/fs/tty/queue.go +++ b/pkg/sentry/fs/tty/queue.go @@ -17,8 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -32,7 +34,7 @@ import ( const waitBufMaxBytes = 131072 // queue represents one of the input or output queues between a pty master and -// slave. Bytes written to a queue are added to the read buffer until it is +// replica. Bytes written to a queue are added to the read buffer until it is // full, at which point they are written to the wait buffer. Bytes are // processed (i.e. undergo termios transformations) as they are added to the // read buffer. The read buffer is readable when its length is nonzero and @@ -85,17 +87,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask { } // readableSize writes the number of readable bytes to userspace. -func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (q *queue) readableSize(t *kernel.Task, args arch.SyscallArguments) error { q.mu.Lock() defer q.mu.Unlock() - var size int32 + size := primitive.Int32(0) if q.readable { - size = int32(len(q.readBuf)) + size = primitive.Int32(len(q.readBuf)) } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := size.CopyOut(t, args[2].Pointer()) return err } @@ -104,8 +104,7 @@ func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.Sysca // as whether the read caused more readable data to become available (whether // data was pushed from the wait buffer to the read buffer). // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) { q.mu.Lock() defer q.mu.Unlock() @@ -145,8 +144,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl // write writes to q from userspace. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) { q.mu.Lock() defer q.mu.Unlock() @@ -188,8 +186,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip // writeBytes writes to q from b. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) writeBytes(b []byte, l *lineDiscipline) { q.mu.Lock() defer q.mu.Unlock() diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/replica.go index 7c7292687..385d230fb 100644 --- a/pkg/sentry/fs/tty/slave.go +++ b/pkg/sentry/fs/tty/replica.go @@ -17,9 +17,11 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -27,11 +29,11 @@ import ( // LINT.IfChange -// slaveInodeOperations are the fs.InodeOperations for the slave end of the +// replicaInodeOperations are the fs.InodeOperations for the replica end of the // Terminal (pts file). // // +stateify savable -type slaveInodeOperations struct { +type replicaInodeOperations struct { fsutil.SimpleFileInode // d is the containing dir. @@ -41,13 +43,13 @@ type slaveInodeOperations struct { t *Terminal } -var _ fs.InodeOperations = (*slaveInodeOperations)(nil) +var _ fs.InodeOperations = (*replicaInodeOperations)(nil) -// newSlaveInode creates an fs.Inode for the slave end of a terminal. +// newReplicaInode creates an fs.Inode for the replica end of a terminal. // -// newSlaveInode takes ownership of t. -func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode { - iops := &slaveInodeOperations{ +// newReplicaInode takes ownership of t. +func newReplicaInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owner fs.FileOwner, p fs.FilePermissions) *fs.Inode { + iops := &replicaInodeOperations{ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, owner, p, linux.DEVPTS_SUPER_MAGIC), d: d, t: t, @@ -64,18 +66,18 @@ func newSlaveInode(ctx context.Context, d *dirInodeOperations, t *Terminal, owne Type: fs.CharacterDevice, // See fs/devpts/inode.c:devpts_fill_super. BlockSize: 1024, - DeviceFileMajor: linux.UNIX98_PTY_SLAVE_MAJOR, + DeviceFileMajor: linux.UNIX98_PTY_REPLICA_MAJOR, DeviceFileMinor: t.n, }) } // Release implements fs.InodeOperations.Release. -func (si *slaveInodeOperations) Release(ctx context.Context) { +func (si *replicaInodeOperations) Release(ctx context.Context) { si.t.DecRef(ctx) } // Truncate implements fs.InodeOperations.Truncate. -func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { +func (*replicaInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { return nil } @@ -83,14 +85,15 @@ func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error { // // This may race with destruction of the terminal. If the terminal is gone, it // returns ENOENT. -func (si *slaveInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, d, flags, &slaveFileOperations{si: si}), nil +func (si *replicaInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + return fs.NewFile(ctx, d, flags, &replicaFileOperations{si: si}), nil } -// slaveFileOperations are the fs.FileOperations for the slave end of a terminal. +// replicaFileOperations are the fs.FileOperations for the replica end of a +// terminal. // // +stateify savable -type slaveFileOperations struct { +type replicaFileOperations struct { fsutil.FilePipeSeek `state:"nosave"` fsutil.FileNotDirReaddir `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` @@ -100,79 +103,84 @@ type slaveFileOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` // si is the inode operations. - si *slaveInodeOperations + si *replicaInodeOperations } -var _ fs.FileOperations = (*slaveFileOperations)(nil) +var _ fs.FileOperations = (*replicaFileOperations)(nil) // Release implements fs.FileOperations.Release. -func (sf *slaveFileOperations) Release(context.Context) { +func (sf *replicaFileOperations) Release(context.Context) { } // EventRegister implements waiter.Waitable.EventRegister. -func (sf *slaveFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - sf.si.t.ld.slaveWaiter.EventRegister(e, mask) +func (sf *replicaFileOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + sf.si.t.ld.replicaWaiter.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (sf *slaveFileOperations) EventUnregister(e *waiter.Entry) { - sf.si.t.ld.slaveWaiter.EventUnregister(e) +func (sf *replicaFileOperations) EventUnregister(e *waiter.Entry) { + sf.si.t.ld.replicaWaiter.EventUnregister(e) } // Readiness implements waiter.Waitable.Readiness. -func (sf *slaveFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { - return sf.si.t.ld.slaveReadiness() +func (sf *replicaFileOperations) Readiness(mask waiter.EventMask) waiter.EventMask { + return sf.si.t.ld.replicaReadiness() } // Read implements fs.FileOperations.Read. -func (sf *slaveFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { +func (sf *replicaFileOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { return sf.si.t.ld.inputQueueRead(ctx, dst) } // Write implements fs.FileOperations.Write. -func (sf *slaveFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { +func (sf *replicaFileOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { return sf.si.t.ld.outputQueueWrite(ctx, src) } // Ioctl implements fs.FileOperations.Ioctl. -func (sf *slaveFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (sf *replicaFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the input queue read buffer. - return 0, sf.si.t.ld.inputQueueReadSize(ctx, io, args) + return 0, sf.si.t.ld.inputQueueReadSize(t, args) case linux.TCGETS: - return sf.si.t.ld.getTermios(ctx, io, args) + return sf.si.t.ld.getTermios(t, args) case linux.TCSETS: - return sf.si.t.ld.setTermios(ctx, io, args) + return sf.si.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return sf.si.t.ld.setTermios(ctx, io, args) + return sf.si.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sf.si.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(sf.si.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCGWINSZ: - return 0, sf.si.t.ld.windowSize(ctx, io, args) + return 0, sf.si.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, sf.si.t.ld.setWindowSize(ctx, io, args) + return 0, sf.si.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, sf.si.t.setControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sf.si.t.setControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, sf.si.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sf.si.t.releaseControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return sf.si.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + return sf.si.t.foregroundProcessGroup(ctx, args, false /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return sf.si.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) + return sf.si.t.setForegroundProcessGroup(ctx, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY } } -// LINT.ThenChange(../../fsimpl/devpts/slave.go) +// LINT.ThenChange(../../fsimpl/devpts/replica.go) diff --git a/pkg/sentry/fs/tty/terminal.go b/pkg/sentry/fs/tty/terminal.go index ddcccf4da..4f431d74d 100644 --- a/pkg/sentry/fs/tty/terminal.go +++ b/pkg/sentry/fs/tty/terminal.go @@ -17,10 +17,10 @@ package tty import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" ) // LINT.IfChange @@ -44,19 +44,19 @@ type Terminal struct { // this terminal. This field is immutable. masterKTTY *kernel.TTY - // slaveKTTY contains the controlling process of the slave end of this + // replicaKTTY contains the controlling process of the replica end of this // terminal. This field is immutable. - slaveKTTY *kernel.TTY + replicaKTTY *kernel.TTY } func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal { - termios := linux.DefaultSlaveTermios + termios := linux.DefaultReplicaTermios t := Terminal{ - d: d, - n: n, - ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{Index: n}, - slaveKTTY: &kernel.TTY{Index: n}, + d: d, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{Index: n}, + replicaKTTY: &kernel.TTY{Index: n}, } t.EnableLeakCheck("tty.Terminal") return &t @@ -64,7 +64,7 @@ func newTerminal(ctx context.Context, d *dirInodeOperations, n uint32) *Terminal // setControllingTTY makes tm the controlling terminal of the calling thread // group. -func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("setControllingTTY must be called from a task context") @@ -75,7 +75,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a // releaseControllingTTY removes tm as the controlling terminal of the calling // thread group. -func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("releaseControllingTTY must be called from a task context") @@ -85,7 +85,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar } // foregroundProcessGroup gets the process group ID of tm's foreground process. -func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("foregroundProcessGroup must be called from a task context") @@ -97,24 +97,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a } // Write it out to *arg. - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ - AddressSpaceActive: true, - }) + retP := primitive.Int32(ret) + _, err = retP.CopyOut(task, args[2].Pointer()) return 0, err } // foregroundProcessGroup sets tm's foreground process. -func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("setForegroundProcessGroup must be called from a task context") } // Read in the process group ID. - var pgid int32 - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgid primitive.Int32 + if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } @@ -126,7 +123,7 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY { if isMaster { return tm.masterKTTY } - return tm.slaveKTTY + return tm.replicaKTTY } // LINT.ThenChange(../../fsimpl/devpts/terminal.go) diff --git a/pkg/sentry/fs/tty/tty_test.go b/pkg/sentry/fs/tty/tty_test.go index 2cbc05678..49edee83d 100644 --- a/pkg/sentry/fs/tty/tty_test.go +++ b/pkg/sentry/fs/tty/tty_test.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) +func TestSimpleMasterToReplica(t *testing.T) { + ld := newLineDiscipline(linux.DefaultReplicaTermios) ctx := contexttest.Context(t) inBytes := []byte("hello, tty\n") src := usermem.BytesIOSequence(inBytes) diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 93512c9b6..48e13613a 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -1,7 +1,19 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "root_inode_refs", + out = "root_inode_refs.go", + package = "devpts", + prefix = "rootInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "rootInode", + }, +) + go_library( name = "devpts", srcs = [ @@ -9,13 +21,18 @@ go_library( "line_discipline.go", "master.go", "queue.go", - "slave.go", + "replica.go", + "root_inode_refs.go", "terminal.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", + "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index 7169e91af..f0f2e0be7 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -79,10 +79,11 @@ func (fstype FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds // Construct the root directory. This is always inode id 1. root := &rootInode{ - slaves: make(map[uint32]*slaveInode), + replicas: make(map[uint32]*replicaInode), } root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + root.EnableLeakCheck() root.dentry.Init(root) // Construct the pts master inode and dentry. Linux always uses inode @@ -110,11 +111,13 @@ func (fs *filesystem) Release(ctx context.Context) { // rootInode is the root directory inode for the devpts mounts. type rootInode struct { + implStatFS kernfs.AlwaysValid kernfs.InodeAttrs kernfs.InodeDirectoryNoNewChildren kernfs.InodeNotSymlink kernfs.OrderedChildren + rootInodeRefs locks vfs.FileLocks @@ -130,8 +133,8 @@ type rootInode struct { // mu protects the fields below. mu sync.Mutex - // slaves maps pty ids to slave inodes. - slaves map[uint32]*slaveInode + // replicas maps pty ids to replica inodes. + replicas map[uint32]*replicaInode // nextIdx is the next pty index to use. Must be accessed atomically. // @@ -151,22 +154,22 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) idx := i.nextIdx i.nextIdx++ - // Sanity check that slave with idx does not exist. - if _, ok := i.slaves[idx]; ok { + // Sanity check that replica with idx does not exist. + if _, ok := i.replicas[idx]; ok { panic(fmt.Sprintf("pty index collision; index %d already exists", idx)) } - // Create the new terminal and slave. + // Create the new terminal and replica. t := newTerminal(idx) - slave := &slaveInode{ + replica := &replicaInode{ root: i, t: t, } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - slave.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) - slave.dentry.Init(slave) - i.slaves[idx] = slave + replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.dentry.Init(replica) + i.replicas[idx] = replica return t, nil } @@ -176,16 +179,18 @@ func (i *rootInode) masterClose(t *Terminal) { i.mu.Lock() defer i.mu.Unlock() - // Sanity check that slave with idx exists. - if _, ok := i.slaves[t.n]; !ok { + // Sanity check that replica with idx exists. + if _, ok := i.replicas[t.n]; !ok { panic(fmt.Sprintf("pty with index %d does not exist", t.n)) } - delete(i.slaves, t.n) + delete(i.replicas, t.n) } // Open implements kernfs.Inode.Open. func (i *rootInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } @@ -200,7 +205,7 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error } i.mu.Lock() defer i.mu.Unlock() - if si, ok := i.slaves[uint32(idx)]; ok { + if si, ok := i.replicas[uint32(idx)]; ok { si.dentry.IncRef() return si.dentry.VFSDentry(), nil @@ -212,8 +217,8 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() - ids := make([]int, 0, len(i.slaves)) - for id := range i.slaves { + ids := make([]int, 0, len(i.replicas)) + for id := range i.replicas { ids = append(ids, int(id)) } sort.Ints(ids) @@ -221,7 +226,7 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, dirent := vfs.Dirent{ Name: strconv.FormatUint(uint64(id), 10), Type: linux.DT_CHR, - Ino: i.slaves[uint32(id)].InodeAttrs.Ino(), + Ino: i.replicas[uint32(id)].InodeAttrs.Ino(), NextOff: offset + 1, } if err := cb.Handle(dirent); err != nil { @@ -231,3 +236,15 @@ func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, } return offset, nil } + +// DecRef implements kernfs.Inode.DecRef. +func (i *rootInode) DecRef(context.Context) { + i.rootInodeRefs.DecRef(i.Destroy) +} + +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.DEVPTS_SUPER_MAGIC), nil +} diff --git a/pkg/sentry/fsimpl/devpts/devpts_test.go b/pkg/sentry/fsimpl/devpts/devpts_test.go index b7c149047..448390cfe 100644 --- a/pkg/sentry/fsimpl/devpts/devpts_test.go +++ b/pkg/sentry/fsimpl/devpts/devpts_test.go @@ -22,8 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -func TestSimpleMasterToSlave(t *testing.T) { - ld := newLineDiscipline(linux.DefaultSlaveTermios) +func TestSimpleMasterToReplica(t *testing.T) { + ld := newLineDiscipline(linux.DefaultReplicaTermios) ctx := contexttest.Context(t) inBytes := []byte("hello, tty\n") src := usermem.BytesIOSequence(inBytes) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index f7bc325d1..e6b0e81cf 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -41,7 +42,7 @@ const ( ) // lineDiscipline dictates how input and output are handled between the -// pseudoterminal (pty) master and slave. It can be configured to alter I/O, +// pseudoterminal (pty) master and replica. It can be configured to alter I/O, // modify control characters (e.g. Ctrl-C for SIGINT), etc. The following man // pages are good resources for how to affect the line discipline: // @@ -52,8 +53,8 @@ const ( // // lineDiscipline has a simple structure but supports a multitude of options // (see the above man pages). It consists of two queues of bytes: one from the -// terminal master to slave (the input queue) and one from slave to master (the -// output queue). When bytes are written to one end of the pty, the line +// terminal master to replica (the input queue) and one from replica to master +// (the output queue). When bytes are written to one end of the pty, the line // discipline reads the bytes, modifies them or takes special action if // required, and enqueues them to be read by the other end of the pty: // @@ -62,7 +63,7 @@ const ( // | (inputQueueWrite) +-------------+ (inputQueueRead) | // | | // | v -// masterFD slaveFD +// masterFD replicaFD // ^ | // | | // | output to terminal +--------------+ output from process | @@ -101,8 +102,8 @@ type lineDiscipline struct { // masterWaiter is used to wait on the master end of the TTY. masterWaiter waiter.Queue `state:"zerovalue"` - // slaveWaiter is used to wait on the slave end of the TTY. - slaveWaiter waiter.Queue `state:"zerovalue"` + // replicaWaiter is used to wait on the replica end of the TTY. + replicaWaiter waiter.Queue `state:"zerovalue"` } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { @@ -113,27 +114,23 @@ func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { } // getTermios gets the linux.Termios for the tty. -func (l *lineDiscipline) getTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) getTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.RLock() defer l.termiosMu.RUnlock() // We must copy a Termios struct, not KernelTermios. t := l.termios.ToTermios() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyOut(task, args[2].Pointer()) return 0, err } // setTermios sets a linux.Termios for the tty. -func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (l *lineDiscipline) setTermios(task *kernel.Task, args arch.SyscallArguments) (uintptr, error) { l.termiosMu.Lock() defer l.termiosMu.Unlock() oldCanonEnabled := l.termios.LEnabled(linux.ICANON) // We must copy a Termios struct, not KernelTermios. var t linux.Termios - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &t, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := t.CopyIn(task, args[2].Pointer()) l.termios.FromTermios(t) // If canonical mode is turned off, move bytes from inQueue's wait @@ -144,27 +141,23 @@ func (l *lineDiscipline) setTermios(ctx context.Context, io usermem.IO, args arc l.inQueue.pushWaitBufLocked(l) l.inQueue.readable = true l.inQueue.mu.Unlock() - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return 0, err } -func (l *lineDiscipline) windowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) windowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyOut(t, args[2].Pointer()) return err } -func (l *lineDiscipline) setWindowSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (l *lineDiscipline) setWindowSize(t *kernel.Task, args arch.SyscallArguments) error { l.sizeMu.Lock() defer l.sizeMu.Unlock() - _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &l.size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := l.size.CopyIn(t, args[2].Pointer()) return err } @@ -174,14 +167,14 @@ func (l *lineDiscipline) masterReadiness() waiter.EventMask { return l.inQueue.writeReadiness(&linux.MasterTermios) | l.outQueue.readReadiness(&linux.MasterTermios) } -func (l *lineDiscipline) slaveReadiness() waiter.EventMask { +func (l *lineDiscipline) replicaReadiness() waiter.EventMask { l.termiosMu.RLock() defer l.termiosMu.RUnlock() return l.outQueue.writeReadiness(&l.termios) | l.inQueue.readReadiness(&l.termios) } -func (l *lineDiscipline) inputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.inQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) inputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { + return l.inQueue.readableSize(t, io, args) } func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -194,7 +187,7 @@ func (l *lineDiscipline) inputQueueRead(ctx context.Context, dst usermem.IOSeque if n > 0 { l.masterWaiter.Notify(waiter.EventOut) if pushed { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) } return n, nil } @@ -209,14 +202,14 @@ func (l *lineDiscipline) inputQueueWrite(ctx context.Context, src usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventIn) + l.replicaWaiter.Notify(waiter.EventIn) return n, nil } return 0, syserror.ErrWouldBlock } -func (l *lineDiscipline) outputQueueReadSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { - return l.outQueue.readableSize(ctx, io, args) +func (l *lineDiscipline) outputQueueReadSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { + return l.outQueue.readableSize(t, io, args) } func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequence) (int64, error) { @@ -227,7 +220,7 @@ func (l *lineDiscipline) outputQueueRead(ctx context.Context, dst usermem.IOSequ return 0, err } if n > 0 { - l.slaveWaiter.Notify(waiter.EventOut) + l.replicaWaiter.Notify(waiter.EventOut) if pushed { l.masterWaiter.Notify(waiter.EventIn) } diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index 3bb397f71..83d790b38 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -17,9 +17,11 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/unimpl" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -30,6 +32,7 @@ import ( // masterInode is the inode for the master end of the Terminal. type masterInode struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeNotDirectory @@ -130,46 +133,51 @@ func (mfd *masterFileDescription) Write(ctx context.Context, src usermem.IOSeque // Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (mfd *masterFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the output queue read buffer. - return 0, mfd.t.ld.outputQueueReadSize(ctx, io, args) + return 0, mfd.t.ld.outputQueueReadSize(t, io, args) case linux.TCGETS: // N.B. TCGETS on the master actually returns the configuration - // of the slave end. - return mfd.t.ld.getTermios(ctx, io, args) + // of the replica end. + return mfd.t.ld.getTermios(t, args) case linux.TCSETS: // N.B. TCSETS on the master actually affects the configuration - // of the slave end. - return mfd.t.ld.setTermios(ctx, io, args) + // of the replica end. + return mfd.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return mfd.t.ld.setTermios(ctx, io, args) + return mfd.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(mfd.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(mfd.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCSPTLCK: // TODO(b/29356795): Implement pty locking. For now just pretend we do. return 0, nil case linux.TIOCGWINSZ: - return 0, mfd.t.ld.windowSize(ctx, io, args) + return 0, mfd.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, mfd.t.ld.setWindowSize(ctx, io, args) + return 0, mfd.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, mfd.t.setControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mfd.t.setControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, mfd.t.releaseControllingTTY(ctx, io, args, true /* isMaster */) + return 0, mfd.t.releaseControllingTTY(ctx, args, true /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return mfd.t.foregroundProcessGroup(ctx, io, args, true /* isMaster */) + return mfd.t.foregroundProcessGroup(ctx, args, true /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return mfd.t.setForegroundProcessGroup(ctx, io, args, true /* isMaster */) + return mfd.t.setForegroundProcessGroup(ctx, args, true /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY diff --git a/pkg/sentry/fsimpl/devpts/queue.go b/pkg/sentry/fsimpl/devpts/queue.go index dffb4232c..55bff3e60 100644 --- a/pkg/sentry/fsimpl/devpts/queue.go +++ b/pkg/sentry/fsimpl/devpts/queue.go @@ -17,8 +17,10 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -30,7 +32,7 @@ import ( const waitBufMaxBytes = 131072 // queue represents one of the input or output queues between a pty master and -// slave. Bytes written to a queue are added to the read buffer until it is +// replica. Bytes written to a queue are added to the read buffer until it is // full, at which point they are written to the wait buffer. Bytes are // processed (i.e. undergo termios transformations) as they are added to the // read buffer. The read buffer is readable when its length is nonzero and @@ -83,17 +85,15 @@ func (q *queue) writeReadiness(t *linux.KernelTermios) waiter.EventMask { } // readableSize writes the number of readable bytes to userspace. -func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.SyscallArguments) error { +func (q *queue) readableSize(t *kernel.Task, io usermem.IO, args arch.SyscallArguments) error { q.mu.Lock() defer q.mu.Unlock() - var size int32 + size := primitive.Int32(0) if q.readable { - size = int32(len(q.readBuf)) + size = primitive.Int32(len(q.readBuf)) } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), size, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := size.CopyOut(t, args[2].Pointer()) return err } @@ -102,8 +102,7 @@ func (q *queue) readableSize(ctx context.Context, io usermem.IO, args arch.Sysca // as whether the read caused more readable data to become available (whether // data was pushed from the wait buffer to the read buffer). // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipline) (int64, bool, error) { q.mu.Lock() defer q.mu.Unlock() @@ -143,8 +142,7 @@ func (q *queue) read(ctx context.Context, dst usermem.IOSequence, l *lineDiscipl // write writes to q from userspace. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscipline) (int64, error) { q.mu.Lock() defer q.mu.Unlock() @@ -186,8 +184,7 @@ func (q *queue) write(ctx context.Context, src usermem.IOSequence, l *lineDiscip // writeBytes writes to q from b. // -// Preconditions: -// * l.termiosMu must be held for reading. +// Preconditions: l.termiosMu must be held for reading. func (q *queue) writeBytes(b []byte, l *lineDiscipline) { q.mu.Lock() defer q.mu.Unlock() diff --git a/pkg/sentry/fsimpl/devpts/slave.go b/pkg/sentry/fsimpl/devpts/replica.go index 32e4e1908..58f6c1d3a 100644 --- a/pkg/sentry/fsimpl/devpts/slave.go +++ b/pkg/sentry/fsimpl/devpts/replica.go @@ -17,9 +17,11 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -27,8 +29,9 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// slaveInode is the inode for the slave end of the Terminal. -type slaveInode struct { +// replicaInode is the inode for the replica end of the Terminal. +type replicaInode struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeNotDirectory @@ -46,12 +49,12 @@ type slaveInode struct { t *Terminal } -var _ kernfs.Inode = (*slaveInode)(nil) +var _ kernfs.Inode = (*replicaInode)(nil) // Open implements kernfs.Inode.Open. -func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { +func (si *replicaInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { si.IncRef() - fd := &slaveFileDescription{ + fd := &replicaFileDescription{ inode: si, } fd.LockFD.Init(&si.locks) @@ -64,109 +67,114 @@ func (si *slaveInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs } // Valid implements kernfs.Inode.Valid. -func (si *slaveInode) Valid(context.Context) bool { - // Return valid if the slave still exists. +func (si *replicaInode) Valid(context.Context) bool { + // Return valid if the replica still exists. si.root.mu.Lock() defer si.root.mu.Unlock() - _, ok := si.root.slaves[si.t.n] + _, ok := si.root.replicas[si.t.n] return ok } // Stat implements kernfs.Inode.Stat. -func (si *slaveInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { +func (si *replicaInode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { statx, err := si.InodeAttrs.Stat(ctx, vfsfs, opts) if err != nil { return linux.Statx{}, err } statx.Blksize = 1024 - statx.RdevMajor = linux.UNIX98_PTY_SLAVE_MAJOR + statx.RdevMajor = linux.UNIX98_PTY_REPLICA_MAJOR statx.RdevMinor = si.t.n return statx, nil } // SetStat implements kernfs.Inode.SetStat -func (si *slaveInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { +func (si *replicaInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask&linux.STATX_SIZE != 0 { return syserror.EINVAL } return si.InodeAttrs.SetStat(ctx, vfsfs, creds, opts) } -type slaveFileDescription struct { +type replicaFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl vfs.LockFD - inode *slaveInode + inode *replicaInode } -var _ vfs.FileDescriptionImpl = (*slaveFileDescription)(nil) +var _ vfs.FileDescriptionImpl = (*replicaFileDescription)(nil) // Release implements fs.FileOperations.Release. -func (sfd *slaveFileDescription) Release(ctx context.Context) { +func (sfd *replicaFileDescription) Release(ctx context.Context) { sfd.inode.DecRef(ctx) } // EventRegister implements waiter.Waitable.EventRegister. -func (sfd *slaveFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - sfd.inode.t.ld.slaveWaiter.EventRegister(e, mask) +func (sfd *replicaFileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + sfd.inode.t.ld.replicaWaiter.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (sfd *slaveFileDescription) EventUnregister(e *waiter.Entry) { - sfd.inode.t.ld.slaveWaiter.EventUnregister(e) +func (sfd *replicaFileDescription) EventUnregister(e *waiter.Entry) { + sfd.inode.t.ld.replicaWaiter.EventUnregister(e) } // Readiness implements waiter.Waitable.Readiness. -func (sfd *slaveFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { - return sfd.inode.t.ld.slaveReadiness() +func (sfd *replicaFileDescription) Readiness(mask waiter.EventMask) waiter.EventMask { + return sfd.inode.t.ld.replicaReadiness() } // Read implements vfs.FileDescriptionImpl.Read. -func (sfd *slaveFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { +func (sfd *replicaFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { return sfd.inode.t.ld.inputQueueRead(ctx, dst) } // Write implements vfs.FileDescriptionImpl.Write. -func (sfd *slaveFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { +func (sfd *replicaFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { return sfd.inode.t.ld.outputQueueWrite(ctx, src) } // Ioctl implements vfs.FileDescriptionImpl.Ioctl. -func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func (sfd *replicaFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + // ioctl(2) may only be called from a task goroutine. + return 0, syserror.ENOTTY + } + switch cmd := args[1].Uint(); cmd { case linux.FIONREAD: // linux.FIONREAD == linux.TIOCINQ // Get the number of bytes in the input queue read buffer. - return 0, sfd.inode.t.ld.inputQueueReadSize(ctx, io, args) + return 0, sfd.inode.t.ld.inputQueueReadSize(t, io, args) case linux.TCGETS: - return sfd.inode.t.ld.getTermios(ctx, io, args) + return sfd.inode.t.ld.getTermios(t, args) case linux.TCSETS: - return sfd.inode.t.ld.setTermios(ctx, io, args) + return sfd.inode.t.ld.setTermios(t, args) case linux.TCSETSW: // TODO(b/29356795): This should drain the output queue first. - return sfd.inode.t.ld.setTermios(ctx, io, args) + return sfd.inode.t.ld.setTermios(t, args) case linux.TIOCGPTN: - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), uint32(sfd.inode.t.n), usermem.IOOpts{ - AddressSpaceActive: true, - }) + nP := primitive.Uint32(sfd.inode.t.n) + _, err := nP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCGWINSZ: - return 0, sfd.inode.t.ld.windowSize(ctx, io, args) + return 0, sfd.inode.t.ld.windowSize(t, args) case linux.TIOCSWINSZ: - return 0, sfd.inode.t.ld.setWindowSize(ctx, io, args) + return 0, sfd.inode.t.ld.setWindowSize(t, args) case linux.TIOCSCTTY: // Make the given terminal the controlling terminal of the // calling process. - return 0, sfd.inode.t.setControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sfd.inode.t.setControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCNOTTY: // Release this process's controlling terminal. - return 0, sfd.inode.t.releaseControllingTTY(ctx, io, args, false /* isMaster */) + return 0, sfd.inode.t.releaseControllingTTY(ctx, args, false /* isMaster */) case linux.TIOCGPGRP: // Get the foreground process group. - return sfd.inode.t.foregroundProcessGroup(ctx, io, args, false /* isMaster */) + return sfd.inode.t.foregroundProcessGroup(ctx, args, false /* isMaster */) case linux.TIOCSPGRP: // Set the foreground process group. - return sfd.inode.t.setForegroundProcessGroup(ctx, io, args, false /* isMaster */) + return sfd.inode.t.setForegroundProcessGroup(ctx, args, false /* isMaster */) default: maybeEmitUnimplementedEvent(ctx, cmd) return 0, syserror.ENOTTY @@ -174,24 +182,24 @@ func (sfd *slaveFileDescription) Ioctl(ctx context.Context, io usermem.IO, args } // SetStat implements vfs.FileDescriptionImpl.SetStat. -func (sfd *slaveFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { +func (sfd *replicaFileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() return sfd.inode.SetStat(ctx, fs, creds, opts) } // Stat implements vfs.FileDescriptionImpl.Stat. -func (sfd *slaveFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { +func (sfd *replicaFileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { fs := sfd.vfsfd.VirtualDentry().Mount().Filesystem() return sfd.inode.Stat(ctx, fs, opts) } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. -func (sfd *slaveFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { +func (sfd *replicaFileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return sfd.Locks().LockPOSIX(ctx, &sfd.vfsfd, uid, t, start, length, whence, block) } // UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. -func (sfd *slaveFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { +func (sfd *replicaFileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { return sfd.Locks().UnlockPOSIX(ctx, &sfd.vfsfd, uid, start, length, whence) } diff --git a/pkg/sentry/fsimpl/devpts/terminal.go b/pkg/sentry/fsimpl/devpts/terminal.go index 7d2781c54..510bd6d89 100644 --- a/pkg/sentry/fsimpl/devpts/terminal.go +++ b/pkg/sentry/fsimpl/devpts/terminal.go @@ -17,9 +17,9 @@ package devpts import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/usermem" ) // Terminal is a pseudoterminal. @@ -36,25 +36,25 @@ type Terminal struct { // this terminal. This field is immutable. masterKTTY *kernel.TTY - // slaveKTTY contains the controlling process of the slave end of this + // replicaKTTY contains the controlling process of the replica end of this // terminal. This field is immutable. - slaveKTTY *kernel.TTY + replicaKTTY *kernel.TTY } func newTerminal(n uint32) *Terminal { - termios := linux.DefaultSlaveTermios + termios := linux.DefaultReplicaTermios t := Terminal{ - n: n, - ld: newLineDiscipline(termios), - masterKTTY: &kernel.TTY{Index: n}, - slaveKTTY: &kernel.TTY{Index: n}, + n: n, + ld: newLineDiscipline(termios), + masterKTTY: &kernel.TTY{Index: n}, + replicaKTTY: &kernel.TTY{Index: n}, } return &t } // setControllingTTY makes tm the controlling terminal of the calling thread // group. -func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) setControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("setControllingTTY must be called from a task context") @@ -65,7 +65,7 @@ func (tm *Terminal) setControllingTTY(ctx context.Context, io usermem.IO, args a // releaseControllingTTY removes tm as the controlling terminal of the calling // thread group. -func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) error { +func (tm *Terminal) releaseControllingTTY(ctx context.Context, args arch.SyscallArguments, isMaster bool) error { task := kernel.TaskFromContext(ctx) if task == nil { panic("releaseControllingTTY must be called from a task context") @@ -75,7 +75,7 @@ func (tm *Terminal) releaseControllingTTY(ctx context.Context, io usermem.IO, ar } // foregroundProcessGroup gets the process group ID of tm's foreground process. -func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) foregroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("foregroundProcessGroup must be called from a task context") @@ -87,24 +87,21 @@ func (tm *Terminal) foregroundProcessGroup(ctx context.Context, io usermem.IO, a } // Write it out to *arg. - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(ret), usermem.IOOpts{ - AddressSpaceActive: true, - }) + retP := primitive.Int32(ret) + _, err = retP.CopyOut(task, args[2].Pointer()) return 0, err } // foregroundProcessGroup sets tm's foreground process. -func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, io usermem.IO, args arch.SyscallArguments, isMaster bool) (uintptr, error) { +func (tm *Terminal) setForegroundProcessGroup(ctx context.Context, args arch.SyscallArguments, isMaster bool) (uintptr, error) { task := kernel.TaskFromContext(ctx) if task == nil { panic("setForegroundProcessGroup must be called from a task context") } // Read in the process group ID. - var pgid int32 - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgid, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgid primitive.Int32 + if _, err := pgid.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } @@ -116,5 +113,5 @@ func (tm *Terminal) tty(isMaster bool) *kernel.TTY { if isMaster { return tm.masterKTTY } - return tm.slaveKTTY + return tm.replicaKTTY } diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD index aa0c2ad8c..01bbee5ad 100644 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ b/pkg/sentry/fsimpl/devtmpfs/BUILD @@ -24,6 +24,7 @@ go_test( library = ":devtmpfs", deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fspath", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/tmpfs", diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go index 2ed5fa8a9..a23094e54 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs.go @@ -18,6 +18,7 @@ package devtmpfs import ( "fmt" + "path" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -79,7 +80,7 @@ type Accessor struct { // NewAccessor returns an Accessor that supports creation of device special // files in the devtmpfs instance registered with name fsTypeName in vfsObj. func NewAccessor(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, fsTypeName string) (*Accessor, error) { - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "devtmpfs" /* source */, fsTypeName, &vfs.MountOptions{}) if err != nil { return nil, err } @@ -150,13 +151,11 @@ func (a *Accessor) CreateDeviceFile(ctx context.Context, pathname string, kind v // Create any parent directories. See // devtmpfs.c:handle_create()=>path_create(). - for it := fspath.Parse(pathname).Begin; it.NextOk(); it = it.Next() { - pop := a.pathOperationAt(it.String()) - if err := a.vfsObj.MkdirAt(actx, a.creds, pop, &vfs.MkdirOptions{ - Mode: 0755, - }); err != nil { - return fmt.Errorf("failed to create directory %q: %v", it.String(), err) - } + parent := path.Dir(pathname) + if err := a.vfsObj.MkdirAllAt(ctx, parent, a.root, a.creds, &vfs.MkdirOptions{ + Mode: 0755, + }); err != nil { + return fmt.Errorf("failed to create device parent directory %q: %v", parent, err) } // NOTE: Linux's devtmpfs refuses to automatically delete files it didn't diff --git a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go index 747867cca..3a38b8bb4 100644 --- a/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go +++ b/pkg/sentry/fsimpl/devtmpfs/devtmpfs_test.go @@ -15,9 +15,11 @@ package devtmpfs import ( + "path" "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" @@ -25,10 +27,13 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" ) -func TestDevtmpfs(t *testing.T) { +const devPath = "/dev" + +func setupDevtmpfs(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesystem, vfs.VirtualDentry, func()) { + t.Helper() + ctx := contexttest.Context(t) creds := auth.CredentialsFromContext(ctx) - vfsObj := &vfs.VirtualFilesystem{} if err := vfsObj.Init(ctx); err != nil { t.Fatalf("VFS init: %v", err) @@ -43,14 +48,11 @@ func TestDevtmpfs(t *testing.T) { }) // Create a test mount namespace with devtmpfs mounted at "/dev". - const devPath = "/dev" - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "tmpfs" /* source */, "tmpfs" /* fsTypeName */, &vfs.MountOptions{}) if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } - defer mntns.DecRef(ctx) root := mntns.Root() - defer root.DecRef(ctx) devpop := vfs.PathOperation{ Root: root, Start: root, @@ -61,10 +63,20 @@ func TestDevtmpfs(t *testing.T) { }); err != nil { t.Fatalf("failed to create mount point: %v", err) } - if err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil { + if _, err := vfsObj.MountAt(ctx, creds, "devtmpfs" /* source */, &devpop, "devtmpfs" /* fsTypeName */, &vfs.MountOptions{}); err != nil { t.Fatalf("failed to mount devtmpfs: %v", err) } + return ctx, creds, vfsObj, root, func() { + root.DecRef(ctx) + mntns.DecRef(ctx) + } +} + +func TestUserspaceInit(t *testing.T) { + ctx, creds, vfsObj, root, cleanup := setupDevtmpfs(t) + defer cleanup() + a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs") if err != nil { t.Fatalf("failed to create devtmpfs.Accessor: %v", err) @@ -75,48 +87,143 @@ func TestDevtmpfs(t *testing.T) { if err := a.UserspaceInit(ctx); err != nil { t.Fatalf("failed to userspace-initialize devtmpfs: %v", err) } + // Created files should be visible in the test mount namespace. - abspath := devPath + "/fd" - target, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }) - if want := "/proc/self/fd"; err != nil || target != want { - t.Fatalf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, target, err, want) + links := []struct { + source string + target string + }{ + { + source: "fd", + target: "/proc/self/fd", + }, + { + source: "stdin", + target: "/proc/self/fd/0", + }, + { + source: "stdout", + target: "/proc/self/fd/1", + }, + { + source: "stderr", + target: "/proc/self/fd/2", + }, + { + source: "ptmx", + target: "pts/ptmx", + }, } - // Create a dummy device special file using a devtmpfs.Accessor. - const ( - pathInDev = "dummy" - kind = vfs.CharDevice - major = 12 - minor = 34 - perms = 0600 - wantMode = linux.S_IFCHR | perms - ) - if err := a.CreateDeviceFile(ctx, pathInDev, kind, major, minor, perms); err != nil { - t.Fatalf("failed to create device file: %v", err) + for _, link := range links { + abspath := path.Join(devPath, link.source) + if gotTarget, err := vfsObj.ReadlinkAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(abspath), + }); err != nil || gotTarget != link.target { + t.Errorf("readlink(%q): got (%q, %v), wanted (%q, nil)", abspath, gotTarget, err, link.target) + } } - // The device special file should be visible in the test mount namespace. - abspath = devPath + "/" + pathInDev - stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(abspath), - }, &vfs.StatOptions{ - Mask: linux.STATX_TYPE | linux.STATX_MODE, - }) - if err != nil { - t.Fatalf("failed to stat device file at %q: %v", abspath, err) + + dirs := []string{"shm", "pts"} + for _, dir := range dirs { + abspath := path.Join(devPath, dir) + statx, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(abspath), + }, &vfs.StatOptions{ + Mask: linux.STATX_MODE, + }) + if err != nil { + t.Errorf("stat(%q): got error %v ", abspath, err) + continue + } + if want := uint16(0755) | linux.S_IFDIR; statx.Mode != want { + t.Errorf("stat(%q): got mode %x, want %x", abspath, statx.Mode, want) + } } - if stat.Mode != wantMode { - t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode) +} + +func TestCreateDeviceFile(t *testing.T) { + ctx, creds, vfsObj, root, cleanup := setupDevtmpfs(t) + defer cleanup() + + a, err := NewAccessor(ctx, vfsObj, creds, "devtmpfs") + if err != nil { + t.Fatalf("failed to create devtmpfs.Accessor: %v", err) } - if stat.RdevMajor != major { - t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, major) + defer a.Release(ctx) + + devFiles := []struct { + path string + kind vfs.DeviceKind + major uint32 + minor uint32 + perms uint16 + }{ + { + path: "dummy", + kind: vfs.CharDevice, + major: 12, + minor: 34, + perms: 0600, + }, + { + path: "foo/bar", + kind: vfs.BlockDevice, + major: 13, + minor: 35, + perms: 0660, + }, + { + path: "foo/baz", + kind: vfs.CharDevice, + major: 12, + minor: 40, + perms: 0666, + }, + { + path: "a/b/c/d/e", + kind: vfs.BlockDevice, + major: 12, + minor: 34, + perms: 0600, + }, } - if stat.RdevMinor != minor { - t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, minor) + + for _, f := range devFiles { + if err := a.CreateDeviceFile(ctx, f.path, f.kind, f.major, f.minor, f.perms); err != nil { + t.Fatalf("failed to create device file: %v", err) + } + // The device special file should be visible in the test mount namespace. + abspath := path.Join(devPath, f.path) + stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(abspath), + }, &vfs.StatOptions{ + Mask: linux.STATX_TYPE | linux.STATX_MODE, + }) + if err != nil { + t.Fatalf("failed to stat device file at %q: %v", abspath, err) + } + if stat.RdevMajor != f.major { + t.Errorf("major device number: got %v, wanted %v", stat.RdevMajor, f.major) + } + if stat.RdevMinor != f.minor { + t.Errorf("minor device number: got %v, wanted %v", stat.RdevMinor, f.minor) + } + wantMode := f.perms + switch f.kind { + case vfs.CharDevice: + wantMode |= linux.S_IFCHR + case vfs.BlockDevice: + wantMode |= linux.S_IFBLK + } + if stat.Mode != wantMode { + t.Errorf("device file mode: got %v, wanted %v", stat.Mode, wantMode) + } } } diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 812171fa3..bb0bf3a07 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -30,7 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// EventFileDescription implements FileDescriptionImpl for file-based event +// EventFileDescription implements vfs.FileDescriptionImpl for file-based event // notification (eventfd). Eventfds are usually internal to the Sentry but in // certain situations they may be converted into a host-backed eventfd. type EventFileDescription struct { @@ -106,7 +106,7 @@ func (efd *EventFileDescription) HostFD() (int, error) { return efd.hostfd, nil } -// Release implements FileDescriptionImpl.Release() +// Release implements vfs.FileDescriptionImpl.Release. func (efd *EventFileDescription) Release(context.Context) { efd.mu.Lock() defer efd.mu.Unlock() @@ -119,7 +119,7 @@ func (efd *EventFileDescription) Release(context.Context) { } } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { if dst.NumBytes() < 8 { return 0, syscall.EINVAL @@ -130,7 +130,7 @@ func (efd *EventFileDescription) Read(ctx context.Context, dst usermem.IOSequenc return 8, nil } -// Write implements FileDescriptionImpl.Write. +// Write implements vfs.FileDescriptionImpl.Write. func (efd *EventFileDescription) Write(ctx context.Context, src usermem.IOSequence, _ vfs.WriteOptions) (int64, error) { if src.NumBytes() < 8 { return 0, syscall.EINVAL diff --git a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go index 8f7d5a9bb..c349b886e 100644 --- a/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go +++ b/pkg/sentry/fsimpl/ext/benchmark/benchmark_test.go @@ -59,7 +59,11 @@ func setUp(b *testing.B, imagePath string) (context.Context, *vfs.VirtualFilesys vfsObj.MustRegisterFilesystemType("extfs", ext.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, imagePath, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }) if err != nil { f.Close() return nil, nil, nil, nil, err @@ -90,7 +94,7 @@ func mount(b *testing.B, imagePath string, vfsfs *vfs.VirtualFilesystem, pop *vf ctx := contexttest.Context(b) creds := auth.CredentialsFromContext(ctx) - if err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ + if _, err := vfsfs.MountAt(ctx, creds, imagePath, pop, "extfs", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ InternalData: int(f.Fd()), }, diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go index 2dbaee287..0989558cd 100644 --- a/pkg/sentry/fsimpl/ext/ext_test.go +++ b/pkg/sentry/fsimpl/ext/ext_test.go @@ -71,7 +71,11 @@ func setUp(t *testing.T, imagePath string) (context.Context, *vfs.VirtualFilesys vfsObj.MustRegisterFilesystemType("extfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.GetFilesystemOptions{InternalData: int(f.Fd())}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, localImagePath, "extfs", &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + InternalData: int(f.Fd()), + }, + }) if err != nil { f.Close() return nil, nil, nil, nil, err diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index c714ddf73..a4a6d8c55 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -81,9 +81,9 @@ var _ vfs.FilesystemImpl = (*filesystem)(nil) // stepLocked is loosely analogous to fs/namei.c:walk_component(). // // Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -// - !rp.Done(). -// - inode == vfsd.Impl().(*Dentry).inode. +// * filesystem.mu must be locked (for writing if write param is true). +// * !rp.Done(). +// * inode == vfsd.Impl().(*Dentry).inode. func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, inode *inode, write bool) (*vfs.Dentry, *inode, error) { if !inode.isDir() { return nil, nil, syserror.ENOTDIR @@ -166,7 +166,7 @@ func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, in // walkLocked is loosely analogous to Linux's fs/namei.c:path_lookupat(). // // Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). +// * filesystem.mu must be locked (for writing if write param is true). func walkLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() inode := vfsd.Impl().(*dentry).inode @@ -194,8 +194,8 @@ func walkLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.De // walkParentLocked is loosely analogous to Linux's fs/namei.c:path_parentat(). // // Preconditions: -// - filesystem.mu must be locked (for writing if write param is true). -// - !rp.Done(). +// * filesystem.mu must be locked (for writing if write param is true). +// * !rp.Done(). func walkParentLocked(ctx context.Context, rp *vfs.ResolvingPath, write bool) (*vfs.Dentry, *inode, error) { vfsd := rp.Start() inode := vfsd.Impl().(*dentry).inode @@ -490,7 +490,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { _, inode, err := fs.walk(ctx, rp, false) if err != nil { @@ -504,8 +504,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { _, _, err := fs.walk(ctx, rp, false) if err != nil { return nil, err @@ -513,8 +513,8 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si return nil, syserror.ENOTSUP } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { _, _, err := fs.walk(ctx, rp, false) if err != nil { return "", err @@ -522,8 +522,8 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return "", syserror.ENOTSUP } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { _, _, err := fs.walk(ctx, rp, false) if err != nil { return err @@ -531,8 +531,8 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return syserror.ENOTSUP } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { _, _, err := fs.walk(ctx, rp, false) if err != nil { return err diff --git a/pkg/sentry/fsimpl/ext/symlink.go b/pkg/sentry/fsimpl/ext/symlink.go index 2fd0d1fa8..f33592d59 100644 --- a/pkg/sentry/fsimpl/ext/symlink.go +++ b/pkg/sentry/fsimpl/ext/symlink.go @@ -61,7 +61,7 @@ func (in *inode) isSymlink() bool { return ok } -// symlinkFD represents a symlink file description and implements implements +// symlinkFD represents a symlink file description and implements // vfs.FileDescriptionImpl. which may only be used if open options contains // O_PATH. For this reason most of the functions return EBADF. type symlinkFD struct { diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 999111deb..045d7ab08 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -15,21 +15,41 @@ go_template_instance( }, ) +go_template_instance( + name = "inode_refs", + out = "inode_refs.go", + package = "fuse", + prefix = "inode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "inode", + }, +) + go_library( name = "fuse", srcs = [ "connection.go", + "connection_control.go", "dev.go", + "directory.go", + "file.go", "fusefs.go", - "init.go", + "inode_refs.go", + "read_write.go", "register.go", + "regular_file.go", "request_list.go", + "request_response.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", @@ -39,7 +59,6 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", "@org_golang_x_sys//unix:go_default_library", ], ) @@ -47,10 +66,15 @@ go_library( go_test( name = "fuse_test", size = "small", - srcs = ["dev_test.go"], + srcs = [ + "connection_test.go", + "dev_test.go", + "utils_test.go", + ], library = ":fuse", deps = [ "//pkg/abi/linux", + "//pkg/marshal", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", @@ -58,6 +82,5 @@ go_test( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go index 6df2728ab..dbc5e1954 100644 --- a/pkg/sentry/fsimpl/fuse/connection.go +++ b/pkg/sentry/fsimpl/fuse/connection.go @@ -15,31 +15,17 @@ package fuse import ( - "errors" - "fmt" "sync" - "sync/atomic" - "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) -// maxActiveRequestsDefault is the default setting controlling the upper bound -// on the number of active requests at any given time. -const maxActiveRequestsDefault = 10000 - -// Ordinary requests have even IDs, while interrupts IDs are odd. -// Used to increment the unique ID for each FUSE request. -var reqIDStep uint64 = 2 - const ( // fuseDefaultMaxBackground is the default value for MaxBackground. fuseDefaultMaxBackground = 12 @@ -52,43 +38,36 @@ const ( fuseDefaultMaxPagesPerReq = 32 ) -// Request represents a FUSE operation request that hasn't been sent to the -// server yet. -// -// +stateify savable -type Request struct { - requestEntry - - id linux.FUSEOpID - hdr *linux.FUSEHeaderIn - data []byte -} - -// Response represents an actual response from the server, including the -// response payload. -// -// +stateify savable -type Response struct { - opcode linux.FUSEOpcode - hdr linux.FUSEHeaderOut - data []byte -} - // connection is the struct by which the sentry communicates with the FUSE server daemon. +// Lock order: +// - conn.fd.mu +// - conn.mu +// - conn.asyncMu type connection struct { fd *DeviceFD + // mu protects access to struct memebers. + mu sync.Mutex + + // attributeVersion is the version of connection's attributes. + attributeVersion uint64 + + // We target FUSE 7.23. // The following FUSE_INIT flags are currently unsupported by this implementation: - // - FUSE_ATOMIC_O_TRUNC: requires open(..., O_TRUNC) // - FUSE_EXPORT_SUPPORT - // - FUSE_HANDLE_KILLPRIV // - FUSE_POSIX_LOCKS: requires POSIX locks // - FUSE_FLOCK_LOCKS: requires POSIX locks // - FUSE_AUTO_INVAL_DATA: requires page caching eviction - // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction // - FUSE_DO_READDIRPLUS/FUSE_READDIRPLUS_AUTO: requires FUSE_READDIRPLUS implementation // - FUSE_ASYNC_DIO - // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler + // - FUSE_PARALLEL_DIROPS (7.25) + // - FUSE_HANDLE_KILLPRIV (7.26) + // - FUSE_POSIX_ACL: affects defaultPermissions, posixACL, xattr handler (7.26) + // - FUSE_ABORT_ERROR (7.27) + // - FUSE_CACHE_SYMLINKS (7.28) + // - FUSE_NO_OPENDIR_SUPPORT (7.29) + // - FUSE_EXPLICIT_INVAL_DATA: requires page caching eviction (7.30) + // - FUSE_MAP_ALIGNMENT (7.31) // initialized after receiving FUSE_INIT reply. // Until it's set, suspend sending FUSE requests. @@ -98,10 +77,6 @@ type connection struct { // initializedChan is used to block requests before initialization. initializedChan chan struct{} - // blocked when there are too many outstading backgrounds requests (NumBackground == MaxBackground). - // TODO(gvisor.dev/issue/3185): update the numBackground accordingly; use a channel to block. - blocked bool - // connected (connection established) when a new FUSE file system is created. // Set to false when: // umount, @@ -109,48 +84,55 @@ type connection struct { // device release. connected bool - // aborted via sysfs. - // TODO(gvisor.dev/issue/3185): abort all queued requests. - aborted bool - // connInitError if FUSE_INIT encountered error (major version mismatch). // Only set in INIT. connInitError bool // connInitSuccess if FUSE_INIT is successful. // Only set in INIT. - // Used for destory. + // Used for destory (not yet implemented). connInitSuccess bool - // TODO(gvisor.dev/issue/3185): All the queue logic are working in progress. - - // NumberBackground is the number of requests in the background. - numBackground uint16 + // aborted via sysfs, and will send ECONNABORTED to read after disconnection (instead of ENODEV). + // Set only if abortErr is true and via fuse control fs (not yet implemented). + // TODO(gvisor.dev/issue/3525): set this to true when user aborts. + aborted bool - // congestionThreshold for NumBackground. - // Negotiated in FUSE_INIT. - congestionThreshold uint16 + // numWating is the number of requests waiting to be + // sent to FUSE device or being processed by FUSE daemon. + numWaiting uint32 - // maxBackground is the maximum number of NumBackground. - // Block connection when it is reached. - // Negotiated in FUSE_INIT. - maxBackground uint16 + // Terminology note: + // + // - `asyncNumMax` is the `MaxBackground` in the FUSE_INIT_IN struct. + // + // - `asyncCongestionThreshold` is the `CongestionThreshold` in the FUSE_INIT_IN struct. + // + // We call the "background" requests in unix term as async requests. + // The "async requests" in unix term is our async requests that expect a reply, + // i.e. `!request.noReply` - // numActiveBackground is the number of requests in background and has being marked as active. - numActiveBackground uint16 + // asyncMu protects the async request fields. + asyncMu sync.Mutex - // numWating is the number of requests waiting for completion. - numWaiting uint32 + // asyncNum is the number of async requests. + // Protected by asyncMu. + asyncNum uint16 - // TODO(gvisor.dev/issue/3185): BgQueue - // some queue for background queued requests. + // asyncCongestionThreshold the number of async requests. + // Negotiated in FUSE_INIT as "CongestionThreshold". + // TODO(gvisor.dev/issue/3529): add congestion control. + // Protected by asyncMu. + asyncCongestionThreshold uint16 - // bgLock protects: - // MaxBackground, CongestionThreshold, NumBackground, - // NumActiveBackground, BgQueue, Blocked. - bgLock sync.Mutex + // asyncNumMax is the maximum number of asyncNum. + // Connection blocks the async requests when it is reached. + // Negotiated in FUSE_INIT as "MaxBackground". + // Protected by asyncMu. + asyncNumMax uint16 // maxRead is the maximum size of a read buffer in in bytes. + // Initialized from a fuse fs parameter. maxRead uint32 // maxWrite is the maximum size of a write buffer in bytes. @@ -165,23 +147,20 @@ type connection struct { // Negotiated and only set in INIT. minor uint32 - // asyncRead if read pages asynchronously. + // atomicOTrunc is true when FUSE does not send a separate SETATTR request + // before open with O_TRUNC flag. // Negotiated and only set in INIT. - asyncRead bool + atomicOTrunc bool - // abortErr is true if kernel need to return an unique read error after abort. + // asyncRead if read pages asynchronously. // Negotiated and only set in INIT. - abortErr bool + asyncRead bool // writebackCache is true for write-back cache policy, // false for write-through policy. // Negotiated and only set in INIT. writebackCache bool - // cacheSymlinks if filesystem needs to cache READLINK responses in page cache. - // Negotiated and only set in INIT. - cacheSymlinks bool - // bigWrites if doing multi-page cached writes. // Negotiated and only set in INIT. bigWrites bool @@ -189,116 +168,70 @@ type connection struct { // dontMask if filestestem does not apply umask to creation modes. // Negotiated in INIT. dontMask bool + + // noOpen if FUSE server doesn't support open operation. + // This flag only influence performance, not correctness of the program. + noOpen bool } // newFUSEConnection creates a FUSE connection to fd. -func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*connection, error) { +func newFUSEConnection(_ context.Context, fd *vfs.FileDescription, opts *filesystemOptions) (*connection, error) { // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to // mount a FUSE filesystem. fuseFD := fd.Impl().(*DeviceFD) - fuseFD.mounted = true // Create the writeBuf for the header to be stored in. hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) fuseFD.writeBuf = make([]byte, hdrLen) fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse) - fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests) + fuseFD.fullQueueCh = make(chan struct{}, opts.maxActiveRequests) fuseFD.writeCursor = 0 return &connection{ - fd: fuseFD, - maxBackground: fuseDefaultMaxBackground, - congestionThreshold: fuseDefaultCongestionThreshold, - maxPages: fuseDefaultMaxPagesPerReq, - initializedChan: make(chan struct{}), - connected: true, - }, nil -} - -// SetInitialized atomically sets the connection as initialized. -func (conn *connection) SetInitialized() { - // Unblock the requests sent before INIT. - close(conn.initializedChan) - - // Close the channel first to avoid the non-atomic situation - // where conn.initialized is true but there are - // tasks being blocked on the channel. - // And it prevents the newer tasks from gaining - // unnecessary higher chance to be issued before the blocked one. - - atomic.StoreInt32(&(conn.initialized), int32(1)) -} - -// IsInitialized atomically check if the connection is initialized. -// pairs with SetInitialized(). -func (conn *connection) Initialized() bool { - return atomic.LoadInt32(&(conn.initialized)) != 0 -} - -// NewRequest creates a new request that can be sent to the FUSE server. -func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { - conn.fd.mu.Lock() - defer conn.fd.mu.Unlock() - conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) - - hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes() - hdr := linux.FUSEHeaderIn{ - Len: uint32(hdrLen + payload.SizeBytes()), - Opcode: opcode, - Unique: conn.fd.nextOpID, - NodeID: ino, - UID: uint32(creds.EffectiveKUID), - GID: uint32(creds.EffectiveKGID), - PID: pid, - } - - buf := make([]byte, hdr.Len) - hdr.MarshalUnsafe(buf[:hdrLen]) - payload.MarshalUnsafe(buf[hdrLen:]) - - return &Request{ - id: hdr.Unique, - hdr: &hdr, - data: buf, + fd: fuseFD, + asyncNumMax: fuseDefaultMaxBackground, + asyncCongestionThreshold: fuseDefaultCongestionThreshold, + maxRead: opts.maxRead, + maxPages: fuseDefaultMaxPagesPerReq, + initializedChan: make(chan struct{}), + connected: true, }, nil } -// Call makes a request to the server and blocks the invoking task until a -// server responds with a response. Task should never be nil. -// Requests will not be sent before the connection is initialized. -// For async tasks, use CallAsync(). -func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) { - // Block requests sent before connection is initalized. - if !conn.Initialized() { - if err := t.Block(conn.initializedChan); err != nil { - return nil, err - } - } - - return conn.call(t, r) +// CallAsync makes an async (aka background) request. +// It's a simple wrapper around Call(). +func (conn *connection) CallAsync(t *kernel.Task, r *Request) error { + r.async = true + _, err := conn.Call(t, r) + return err } -// CallAsync makes an async (aka background) request. -// Those requests either do not expect a response (e.g. release) or -// the response should be handled by others (e.g. init). -// Return immediately unless the connection is blocked (before initialization). -// Async call example: init, release, forget, aio, interrupt. +// Call makes a request to the server. +// Block before the connection is initialized. // When the Request is FUSE_INIT, it will not be blocked before initialization. -func (conn *connection) CallAsync(t *kernel.Task, r *Request) error { +// Task should never be nil. +// +// For a sync request, it blocks the invoking task until +// a server responds with a response. +// +// For an async request (that do not expect a response immediately), +// it returns directly unless being blocked either before initialization +// or when there are too many async requests ongoing. +// +// Example for async request: +// init, readahead, write, async read/write, fuse_notify_reply, +// non-sync release, interrupt, forget. +// +// The forget request does not have a reply, +// as documented in include/uapi/linux/fuse.h:FUSE_FORGET. +func (conn *connection) Call(t *kernel.Task, r *Request) (*Response, error) { // Block requests sent before connection is initalized. if !conn.Initialized() && r.hdr.Opcode != linux.FUSE_INIT { if err := t.Block(conn.initializedChan); err != nil { - return err + return nil, err } } - // This should be the only place that invokes call() with a nil task. - _, err := conn.call(nil, r) - return err -} - -// call makes a call without blocking checks. -func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) { if !conn.connected { return nil, syserror.ENOTCONN } @@ -315,31 +248,6 @@ func (conn *connection) call(t *kernel.Task, r *Request) (*Response, error) { return fut.resolve(t) } -// Error returns the error of the FUSE call. -func (r *Response) Error() error { - errno := r.hdr.Error - if errno >= 0 { - return nil - } - - sysErrNo := syscall.Errno(-errno) - return error(sysErrNo) -} - -// UnmarshalPayload unmarshals the response data into m. -func (r *Response) UnmarshalPayload(m marshal.Marshallable) error { - hdrLen := r.hdr.SizeBytes() - haveDataLen := r.hdr.Len - uint32(hdrLen) - wantDataLen := uint32(m.SizeBytes()) - - if haveDataLen < wantDataLen { - return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen) - } - - m.UnmarshalUnsafe(r.data[hdrLen:]) - return nil -} - // callFuture makes a request to the server and returns a future response. // Call resolve() when the response needs to be fulfilled. func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) { @@ -358,11 +266,6 @@ func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, // if there are always too many ongoing requests all the time. The // supported maxActiveRequests setting should be really high to avoid this. for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests { - if t == nil { - // Since there is no task that is waiting. We must error out. - return nil, errors.New("FUSE request queue full") - } - log.Infof("Blocking request %v from being queued. Too many active requests: %v", r.id, conn.fd.numActiveRequests) conn.fd.mu.Unlock() @@ -378,9 +281,19 @@ func (conn *connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, // callFutureLocked makes a request to the server and returns a future response. func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) { + // Check connected again holding conn.mu. + conn.mu.Lock() + if !conn.connected { + conn.mu.Unlock() + // we checked connected before, + // this must be due to aborted connection. + return nil, syserror.ECONNABORTED + } + conn.mu.Unlock() + conn.fd.queue.PushBack(r) - conn.fd.numActiveRequests += 1 - fut := newFutureResponse(r.hdr.Opcode) + conn.fd.numActiveRequests++ + fut := newFutureResponse(r) conn.fd.completions[r.id] = fut // Signal the readers that there is something to read. @@ -388,50 +301,3 @@ func (conn *connection) callFutureLocked(t *kernel.Task, r *Request) (*futureRes return fut, nil } - -// futureResponse represents an in-flight request, that may or may not have -// completed yet. Convert it to a resolved Response by calling Resolve, but note -// that this may block. -// -// +stateify savable -type futureResponse struct { - opcode linux.FUSEOpcode - ch chan struct{} - hdr *linux.FUSEHeaderOut - data []byte -} - -// newFutureResponse creates a future response to a FUSE request. -func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse { - return &futureResponse{ - opcode: opcode, - ch: make(chan struct{}), - } -} - -// resolve blocks the task until the server responds to its corresponding request, -// then returns a resolved response. -func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) { - // If there is no Task associated with this request - then we don't try to resolve - // the response. Instead, the task writing the response (proxy to the server) will - // process the response on our behalf. - if t == nil { - log.Infof("fuse.Response.resolve: Not waiting on a response from server.") - return nil, nil - } - - if err := t.Block(f.ch); err != nil { - return nil, err - } - - return f.getResponse(), nil -} - -// getResponse creates a Response from the data the futureResponse has. -func (f *futureResponse) getResponse() *Response { - return &Response{ - opcode: f.opcode, - hdr: *f.hdr, - data: f.data, - } -} diff --git a/pkg/sentry/fsimpl/fuse/init.go b/pkg/sentry/fsimpl/fuse/connection_control.go index 779c2bd3f..bfde78559 100644 --- a/pkg/sentry/fsimpl/fuse/init.go +++ b/pkg/sentry/fsimpl/fuse/connection_control.go @@ -15,7 +15,11 @@ package fuse import ( + "sync/atomic" + "syscall" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) @@ -29,9 +33,10 @@ const ( // Follow the same behavior as unix fuse implementation. fuseMaxTimeGranNs = 1000000000 - // Minimum value for MaxWrite. + // Minimum value for MaxWrite and MaxRead. // Follow the same behavior as unix fuse implementation. fuseMinMaxWrite = 4096 + fuseMinMaxRead = 4096 // Temporary default value for max readahead, 128kb. fuseDefaultMaxReadahead = 131072 @@ -49,6 +54,26 @@ var ( MaxUserCongestionThreshold uint16 = fuseDefaultCongestionThreshold ) +// SetInitialized atomically sets the connection as initialized. +func (conn *connection) SetInitialized() { + // Unblock the requests sent before INIT. + close(conn.initializedChan) + + // Close the channel first to avoid the non-atomic situation + // where conn.initialized is true but there are + // tasks being blocked on the channel. + // And it prevents the newer tasks from gaining + // unnecessary higher chance to be issued before the blocked one. + + atomic.StoreInt32(&(conn.initialized), int32(1)) +} + +// IsInitialized atomically check if the connection is initialized. +// pairs with SetInitialized(). +func (conn *connection) Initialized() bool { + return atomic.LoadInt32(&(conn.initialized)) != 0 +} + // InitSend sends a FUSE_INIT request. func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { in := linux.FUSEInitIn{ @@ -70,29 +95,31 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { } // InitRecv receives a FUSE_INIT reply and process it. +// +// Preconditions: conn.asyncMu must not be held if minor verion is newer than 13. func (conn *connection) InitRecv(res *Response, hasSysAdminCap bool) error { if err := res.Error(); err != nil { return err } - var out linux.FUSEInitOut - if err := res.UnmarshalPayload(&out); err != nil { + initRes := fuseInitRes{initLen: res.DataLen()} + if err := res.UnmarshalPayload(&initRes); err != nil { return err } - return conn.initProcessReply(&out, hasSysAdminCap) + return conn.initProcessReply(&initRes.initOut, hasSysAdminCap) } // Process the FUSE_INIT reply from the FUSE server. +// It tries to acquire the conn.asyncMu lock if minor version is newer than 13. func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap bool) error { + // No matter error or not, always set initialzied. + // to unblock the blocked requests. + defer conn.SetInitialized() + // No support for old major fuse versions. if out.Major != linux.FUSE_KERNEL_VERSION { conn.connInitError = true - - // Set the connection as initialized and unblock the blocked requests - // (i.e. return error for them). - conn.SetInitialized() - return nil } @@ -100,29 +127,14 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap conn.connInitSuccess = true conn.minor = out.Minor - // No support for limits before minor version 13. - if out.Minor >= 13 { - conn.bgLock.Lock() - - if out.MaxBackground > 0 { - conn.maxBackground = out.MaxBackground - - if !hasSysAdminCap && - conn.maxBackground > MaxUserBackgroundRequest { - conn.maxBackground = MaxUserBackgroundRequest - } - } - - if out.CongestionThreshold > 0 { - conn.congestionThreshold = out.CongestionThreshold - - if !hasSysAdminCap && - conn.congestionThreshold > MaxUserCongestionThreshold { - conn.congestionThreshold = MaxUserCongestionThreshold - } - } - - conn.bgLock.Unlock() + // No support for negotiating MaxWrite before minor version 5. + if out.Minor >= 5 { + conn.maxWrite = out.MaxWrite + } else { + conn.maxWrite = fuseMinMaxWrite + } + if conn.maxWrite < fuseMinMaxWrite { + conn.maxWrite = fuseMinMaxWrite } // No support for the following flags before minor version 6. @@ -131,8 +143,6 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap conn.bigWrites = out.Flags&linux.FUSE_BIG_WRITES != 0 conn.dontMask = out.Flags&linux.FUSE_DONT_MASK != 0 conn.writebackCache = out.Flags&linux.FUSE_WRITEBACK_CACHE != 0 - conn.cacheSymlinks = out.Flags&linux.FUSE_CACHE_SYMLINKS != 0 - conn.abortErr = out.Flags&linux.FUSE_ABORT_ERROR != 0 // TODO(gvisor.dev/issue/3195): figure out how to use TimeGran (0 < TimeGran <= fuseMaxTimeGranNs). @@ -148,19 +158,90 @@ func (conn *connection) initProcessReply(out *linux.FUSEInitOut, hasSysAdminCap } } - // No support for negotiating MaxWrite before minor version 5. - if out.Minor >= 5 { - conn.maxWrite = out.MaxWrite - } else { - conn.maxWrite = fuseMinMaxWrite + // No support for limits before minor version 13. + if out.Minor >= 13 { + conn.asyncMu.Lock() + + if out.MaxBackground > 0 { + conn.asyncNumMax = out.MaxBackground + + if !hasSysAdminCap && + conn.asyncNumMax > MaxUserBackgroundRequest { + conn.asyncNumMax = MaxUserBackgroundRequest + } + } + + if out.CongestionThreshold > 0 { + conn.asyncCongestionThreshold = out.CongestionThreshold + + if !hasSysAdminCap && + conn.asyncCongestionThreshold > MaxUserCongestionThreshold { + conn.asyncCongestionThreshold = MaxUserCongestionThreshold + } + } + + conn.asyncMu.Unlock() } - if conn.maxWrite < fuseMinMaxWrite { - conn.maxWrite = fuseMinMaxWrite + + return nil +} + +// Abort this FUSE connection. +// It tries to acquire conn.fd.mu, conn.lock, conn.bgLock in order. +// All possible requests waiting or blocking will be aborted. +// +// Preconditions: conn.fd.mu is locked. +func (conn *connection) Abort(ctx context.Context) { + conn.mu.Lock() + conn.asyncMu.Lock() + + if !conn.connected { + conn.asyncMu.Unlock() + conn.mu.Unlock() + conn.fd.mu.Unlock() + return } - // Set connection as initialized and unblock the requests - // issued before init. - conn.SetInitialized() + conn.connected = false - return nil + // Empty the `fd.queue` that holds the requests + // not yet read by the FUSE daemon yet. + // These are a subset of the requests in `fuse.completion` map. + for !conn.fd.queue.Empty() { + req := conn.fd.queue.Front() + conn.fd.queue.Remove(req) + } + + var terminate []linux.FUSEOpID + + // 2. Collect the requests have not been sent to FUSE daemon, + // or have not received a reply. + for unique := range conn.fd.completions { + terminate = append(terminate, unique) + } + + // Release locks to avoid deadlock. + conn.asyncMu.Unlock() + conn.mu.Unlock() + + // 1. The requets blocked before initialization. + // Will reach call() `connected` check and return. + if !conn.Initialized() { + conn.SetInitialized() + } + + // 2. Terminate the requests collected above. + // Set ECONNABORTED error. + // sendError() will remove them from `fd.completion` map. + // Will enter the path of a normally received error. + for _, toTerminate := range terminate { + conn.fd.sendError(ctx, -int32(syscall.ECONNABORTED), toTerminate) + } + + // 3. The requests not yet written to FUSE device. + // Early terminate. + // Will reach callFutureLocked() `connected` check and return. + close(conn.fd.fullQueueCh) + + // TODO(gvisor.dev/issue/3528): Forget all pending forget reqs. } diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go new file mode 100644 index 000000000..91d16c1cf --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "math/rand" + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" +) + +// TestConnectionInitBlock tests if initialization +// correctly blocks and unblocks the connection. +// Since it's unfeasible to test kernelTask.Block() in unit test, +// the code in Call() are not tested here. +func TestConnectionInitBlock(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + + conn, _, err := newTestConnection(s, k, maxActiveRequestsDefault) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + select { + case <-conn.initializedChan: + t.Fatalf("initializedChan should be blocking before SetInitialized") + default: + } + + conn.SetInitialized() + + select { + case <-conn.initializedChan: + default: + t.Fatalf("initializedChan should not be blocking after SetInitialized") + } +} + +func TestConnectionAbort(t *testing.T) { + s := setup(t) + defer s.Destroy() + + k := kernel.KernelFromContext(s.Ctx) + creds := auth.CredentialsFromContext(s.Ctx) + task := kernel.TaskFromContext(s.Ctx) + + const numRequests uint64 = 256 + + conn, _, err := newTestConnection(s, k, numRequests) + if err != nil { + t.Fatalf("newTestConnection: %v", err) + } + + testObj := &testPayload{ + data: rand.Uint32(), + } + + var futNormal []*futureResponse + + for i := 0; i < int(numRequests); i++ { + req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + fut, err := conn.callFutureLocked(task, req) + if err != nil { + t.Fatalf("callFutureLocked failed: %v", err) + } + futNormal = append(futNormal, fut) + } + + conn.Abort(s.Ctx) + + // Abort should unblock the initialization channel. + // Note: no test requests are actually blocked on `conn.initializedChan`. + select { + case <-conn.initializedChan: + default: + t.Fatalf("initializedChan should not be blocking after SetInitialized") + } + + // Abort will return ECONNABORTED error to unblocked requests. + for _, fut := range futNormal { + if fut.getResponse().hdr.Error != -int32(syscall.ECONNABORTED) { + t.Fatalf("Incorrect error code received for aborted connection: %v", fut.getResponse().hdr.Error) + } + } + + // After abort, Call() should return directly with ENOTCONN. + req, err := conn.NewRequest(creds, 0, 0, 0, testObj) + if err != nil { + t.Fatalf("NewRequest creation failed: %v", err) + } + _, err = conn.Call(task, req) + if err != syserror.ENOTCONN { + t.Fatalf("Incorrect error code received for Call() after connection aborted") + } + +} diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index e522ff9a0..f690ef5ad 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -56,9 +55,6 @@ type DeviceFD struct { vfs.DentryMetadataFileDescriptionImpl vfs.NoLockFD - // mounted specifies whether a FUSE filesystem was mounted using the DeviceFD. - mounted bool - // nextOpID is used to create new requests. nextOpID linux.FUSEOpID @@ -99,14 +95,21 @@ type DeviceFD struct { } // Release implements vfs.FileDescriptionImpl.Release. -func (fd *DeviceFD) Release(context.Context) { - fd.fs.conn.connected = false +func (fd *DeviceFD) Release(ctx context.Context) { + if fd.fs != nil { + fd.fs.conn.mu.Lock() + fd.fs.conn.connected = false + fd.fs.conn.mu.Unlock() + + fd.fs.VFSFilesystem().DecRef(ctx) + fd.fs = nil + } } // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } @@ -116,10 +119,16 @@ func (fd *DeviceFD) PRead(ctx context.Context, dst usermem.IOSequence, offset in // Read implements vfs.FileDescriptionImpl.Read. func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } + // Return ENODEV if the filesystem is umounted. + if fd.fs.umounted { + // TODO(gvisor.dev/issue/3525): return ECONNABORTED if aborted via fuse control fs. + return 0, syserror.ENODEV + } + // We require that any Read done on this filesystem have a sane minimum // read buffer. It must have the capacity for the fixed parts of any request // header (Linux uses the request header and the FUSEWriteIn header for this @@ -143,58 +152,82 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R } // readLocked implements the reading of the fuse device while locked with DeviceFD.mu. +// +// Preconditions: dst is large enough for any reasonable request. func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { - if fd.queue.Empty() { - return 0, syserror.ErrWouldBlock - } + var req *Request - var readCursor uint32 - var bytesRead int64 - for { - req := fd.queue.Front() - if dst.NumBytes() < int64(req.hdr.Len) { - // The request is too large. Cannot process it. All requests must be smaller than the - // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT - // handshake. - errno := -int32(syscall.EIO) - if req.hdr.Opcode == linux.FUSE_SETXATTR { - errno = -int32(syscall.E2BIG) - } + // Find the first valid request. + // For the normal case this loop only execute once. + for !fd.queue.Empty() { + req = fd.queue.Front() - // Return the error to the calling task. - if err := fd.sendError(ctx, errno, req); err != nil { - return 0, err - } + if int64(req.hdr.Len)+int64(len(req.payload)) <= dst.NumBytes() { + break + } - // We're done with this request. - fd.queue.Remove(req) + // The request is too large. Cannot process it. All requests must be smaller than the + // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT + // handshake. + errno := -int32(syscall.EIO) + if req.hdr.Opcode == linux.FUSE_SETXATTR { + errno = -int32(syscall.E2BIG) + } - // Restart the read as this request was invalid. - log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.") - return fd.readLocked(ctx, dst, opts) + // Return the error to the calling task. + if err := fd.sendError(ctx, errno, req.hdr.Unique); err != nil { + return 0, err } - n, err := dst.CopyOut(ctx, req.data[readCursor:]) + // We're done with this request. + fd.queue.Remove(req) + req = nil + } + + if req == nil { + return 0, syserror.ErrWouldBlock + } + + // We already checked the size: dst must be able to fit the whole request. + // Now we write the marshalled header, the payload, + // and the potential additional payload + // to the user memory IOSequence. + + n, err := dst.CopyOut(ctx, req.data) + if err != nil { + return 0, err + } + if n != len(req.data) { + return 0, syserror.EIO + } + + if req.hdr.Opcode == linux.FUSE_WRITE { + written, err := dst.DropFirst(n).CopyOut(ctx, req.payload) if err != nil { return 0, err } - readCursor += uint32(n) - bytesRead += int64(n) - - if readCursor >= req.hdr.Len { - // Fully done with this req, remove it from the queue. - fd.queue.Remove(req) - break + if written != len(req.payload) { + return 0, syserror.EIO } + n += int(written) + } + + // Fully done with this req, remove it from the queue. + fd.queue.Remove(req) + + // Remove noReply ones from map of requests expecting a reply. + if req.noReply { + fd.numActiveRequests -= 1 + delete(fd.completions, req.hdr.Unique) } - return bytesRead, nil + return int64(n), nil } // PWrite implements vfs.FileDescriptionImpl.PWrite. func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } @@ -211,10 +244,15 @@ func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs. // writeLocked implements writing to the fuse device while locked with DeviceFD.mu. func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } + // Return ENODEV if the filesystem is umounted. + if fd.fs.umounted { + return 0, syserror.ENODEV + } + var cn, n int64 hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) @@ -276,7 +314,8 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt fut, ok := fd.completions[hdr.Unique] if !ok { - // Server sent us a response for a request we never sent? + // Server sent us a response for a request we never sent, + // or for which we already received a reply (e.g. aborted), an unlikely event. return 0, syserror.EINVAL } @@ -307,8 +346,23 @@ func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opt // Readiness implements vfs.FileDescriptionImpl.Readiness. func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { + fd.mu.Lock() + defer fd.mu.Unlock() + return fd.readinessLocked(mask) +} + +// readinessLocked implements checking the readiness of the fuse device while +// locked with DeviceFD.mu. +func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask { var ready waiter.EventMask - ready |= waiter.EventOut // FD is always writable + + if fd.fs.umounted { + ready |= waiter.EventErr + return ready & mask + } + + // FD is always writable. + ready |= waiter.EventOut if !fd.queue.Empty() { // Have reqs available, FD is readable. ready |= waiter.EventIn @@ -330,7 +384,7 @@ func (fd *DeviceFD) EventUnregister(e *waiter.Entry) { // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { // Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted. - if !fd.mounted { + if fd.fs == nil { return 0, syserror.EPERM } @@ -338,59 +392,59 @@ func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64 } // sendResponse sends a response to the waiting task (if any). +// +// Preconditions: fd.mu must be held. func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error { - // See if the running task need to perform some action before returning. - // Since we just finished writing the future, we can be sure that - // getResponse generates a populated response. - if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil { - return err - } + // Signal the task waiting on a response if any. + defer close(fut.ch) // Signal that the queue is no longer full. select { case fd.fullQueueCh <- struct{}{}: default: } - fd.numActiveRequests -= 1 + fd.numActiveRequests-- + + if fut.async { + return fd.asyncCallBack(ctx, fut.getResponse()) + } - // Signal the task waiting on a response. - close(fut.ch) return nil } -// sendError sends an error response to the waiting task (if any). -func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error { +// sendError sends an error response to the waiting task (if any) by calling sendResponse(). +// +// Preconditions: fd.mu must be held. +func (fd *DeviceFD) sendError(ctx context.Context, errno int32, unique linux.FUSEOpID) error { // Return the error to the calling task. outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes()) respHdr := linux.FUSEHeaderOut{ Len: outHdrLen, Error: errno, - Unique: req.hdr.Unique, + Unique: unique, } fut, ok := fd.completions[respHdr.Unique] if !ok { - // Server sent us a response for a request we never sent? + // A response for a request we never sent, + // or for which we already received a reply (e.g. aborted). return syserror.EINVAL } delete(fd.completions, respHdr.Unique) fut.hdr = &respHdr - if err := fd.sendResponse(ctx, fut); err != nil { - return err - } - - return nil + return fd.sendResponse(ctx, fut) } -// noReceiverAction has the calling kernel.Task do some action if its known that no -// receiver is going to be waiting on the future channel. This is to be used by: -// FUSE_INIT. -func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error { - if r.opcode == linux.FUSE_INIT { +// asyncCallBack executes pre-defined callback function for async requests. +// Currently used by: FUSE_INIT. +func (fd *DeviceFD) asyncCallBack(ctx context.Context, r *Response) error { + switch r.opcode { + case linux.FUSE_INIT: creds := auth.CredentialsFromContext(ctx) rootUserNs := kernel.KernelFromContext(ctx).RootUserNamespace() return fd.fs.conn.InitRecv(r, creds.HasCapabilityIn(linux.CAP_SYS_ADMIN, rootUserNs)) + // TODO(gvisor.dev/issue/3247): support async read: correctly process the response. } return nil diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 1ffe7ccd2..5986133e9 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -16,7 +16,6 @@ package fuse import ( "fmt" - "io" "math/rand" "testing" @@ -28,17 +27,12 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // echoTestOpcode is the Opcode used during testing. The server used in tests // will simply echo the payload back with the appropriate headers. const echoTestOpcode linux.FUSEOpcode = 1000 -type testPayload struct { - data uint32 -} - // TestFUSECommunication tests that the communication layer between the Sentry and the // FUSE server daemon works as expected. func TestFUSECommunication(t *testing.T) { @@ -327,102 +321,3 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F } } } - -func setup(t *testing.T) *testutil.System { - k, err := testutil.Boot() - if err != nil { - t.Fatalf("Error creating kernel: %v", err) - } - - ctx := k.SupervisorContext() - creds := auth.CredentialsFromContext(ctx) - - k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserList: true, - AllowUserMount: true, - }) - - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) - if err != nil { - t.Fatalf("NewMountNamespace(): %v", err) - } - - return testutil.NewSystem(ctx, t, k.VFS(), mntns) -} - -// newTestConnection creates a fuse connection that the sentry can communicate with -// and the FD for the server to communicate with. -func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { - vfsObj := &vfs.VirtualFilesystem{} - fuseDev := &DeviceFD{} - - if err := vfsObj.Init(system.Ctx); err != nil { - return nil, nil, err - } - - vd := vfsObj.NewAnonVirtualDentry("genCountFD") - defer vd.DecRef(system.Ctx) - if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { - return nil, nil, err - } - - fsopts := filesystemOptions{ - maxActiveRequests: maxActiveRequests, - } - fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) - if err != nil { - return nil, nil, err - } - - return fs.conn, &fuseDev.vfsfd, nil -} - -// SizeBytes implements marshal.Marshallable.SizeBytes. -func (t *testPayload) SizeBytes() int { - return 4 -} - -// MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *testPayload) MarshalBytes(dst []byte) { - usermem.ByteOrder.PutUint32(dst[:4], t.data) -} - -// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *testPayload) UnmarshalBytes(src []byte) { - *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} -} - -// Packed implements marshal.Marshallable.Packed. -func (t *testPayload) Packed() bool { - return true -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (t *testPayload) MarshalUnsafe(dst []byte) { - t.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (t *testPayload) UnmarshalUnsafe(src []byte) { - t.UnmarshalBytes(src) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { - panic("not implemented") -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { - panic("not implemented") -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - panic("not implemented") -} - -// WriteTo implements io.WriterTo.WriteTo. -func (t *testPayload) WriteTo(w io.Writer) (int64, error) { - panic("not implemented") -} diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go new file mode 100644 index 000000000..8f220a04b --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/directory.go @@ -0,0 +1,105 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type directoryFD struct { + fileDescription +} + +// Allocate implements directoryFD.Allocate. +func (*directoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.EISDIR +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (*directoryFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (*directoryFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (*directoryFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (*directoryFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EISDIR +} + +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback) error { + fusefs := dir.inode().fs + task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) + + in := linux.FUSEReadIn{ + Fh: dir.Fh, + Offset: uint64(atomic.LoadInt64(&dir.off)), + Size: linux.FUSE_PAGE_SIZE, + Flags: dir.statusFlags(), + } + + // TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS. + req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) + if err != nil { + return err + } + + res, err := fusefs.conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + + var out linux.FUSEDirents + if err := res.UnmarshalPayload(&out); err != nil { + return err + } + + for _, fuseDirent := range out.Dirents { + nextOff := int64(fuseDirent.Meta.Off) + dirent := vfs.Dirent{ + Name: fuseDirent.Name, + Type: uint8(fuseDirent.Meta.Type), + Ino: fuseDirent.Meta.Ino, + NextOff: nextOff, + } + + if err := callback.Handle(dirent); err != nil { + return err + } + atomic.StoreInt64(&dir.off, nextOff) + } + + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go new file mode 100644 index 000000000..83f2816b7 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/file.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +// fileDescription implements vfs.FileDescriptionImpl for fuse. +type fileDescription struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.NoLockFD + + // the file handle used in userspace. + Fh uint64 + + // Nonseekable is indicate cannot perform seek on a file. + Nonseekable bool + + // DirectIO suggest fuse to use direct io operation. + DirectIO bool + + // OpenFlag is the flag returned by open. + OpenFlag uint32 + + // off is the file offset. + off int64 +} + +func (fd *fileDescription) dentry() *kernfs.Dentry { + return fd.vfsfd.Dentry().Impl().(*kernfs.Dentry) +} + +func (fd *fileDescription) inode() *inode { + return fd.dentry().Inode().(*inode) +} + +func (fd *fileDescription) filesystem() *vfs.Filesystem { + return fd.vfsfd.VirtualDentry().Mount().Filesystem() +} + +func (fd *fileDescription) statusFlags() uint32 { + return fd.vfsfd.StatusFlags() +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *fileDescription) Release(ctx context.Context) { + // no need to release if FUSE server doesn't implement Open. + conn := fd.inode().fs.conn + if conn.noOpen { + return + } + + in := linux.FUSEReleaseIn{ + Fh: fd.Fh, + Flags: fd.statusFlags(), + } + // TODO(gvisor.dev/issue/3245): add logic when we support file lock owner. + var opcode linux.FUSEOpcode + if fd.inode().Mode().IsDir() { + opcode = linux.FUSE_RELEASEDIR + } else { + opcode = linux.FUSE_RELEASE + } + kernelTask := kernel.TaskFromContext(ctx) + // ignoring errors and FUSE server reply is analogous to Linux's behavior. + req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) + if err != nil { + // No way to invoke Call() with an errored request. + return + } + // The reply will be ignored since no callback is defined in asyncCallBack(). + conn.CallAsync(kernelTask, req) +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + return 0, nil +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, nil +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + return 0, nil +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.filesystem() + inode := fd.inode() + return inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + fs := fd.filesystem() + creds := auth.CredentialsFromContext(ctx) + return fd.inode().setAttr(ctx, fs, creds, opts, true, fd.Fh) +} diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 83c24ec25..b3573f80d 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -16,21 +16,30 @@ package fuse import ( + "math" "strconv" + "sync" + "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/waiter" ) // Name is the default filesystem name. const Name = "fuse" +// maxActiveRequestsDefault is the default setting controlling the upper bound +// on the number of active requests at any given time. +const maxActiveRequestsDefault = 10000 + // FilesystemType implements vfs.FilesystemType. type FilesystemType struct{} @@ -56,6 +65,11 @@ type filesystemOptions struct { // exist at any time. Any further requests will block when trying to // Call the server. maxActiveRequests uint64 + + // maxRead is the max number of bytes to read, + // specified as "max_read" in fs parameters. + // If not specified by user, use math.MaxUint32 as default value. + maxRead uint32 } // filesystem implements vfs.FilesystemImpl. @@ -69,6 +83,9 @@ type filesystem struct { // opts is the options the fusefs is initialized with. opts *filesystemOptions + + // umounted is true if filesystem.Release() has been called. + umounted bool } // Name implements vfs.FilesystemType.Name. @@ -142,14 +159,29 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Set the maxInFlightRequests option. fsopts.maxActiveRequests = maxActiveRequestsDefault + if maxReadStr, ok := mopts["max_read"]; ok { + delete(mopts, "max_read") + maxRead, err := strconv.ParseUint(maxReadStr, 10, 32) + if err != nil { + log.Warningf("%s.GetFilesystem: invalid max_read: max_read=%s", fsType.Name(), maxReadStr) + return nil, nil, syserror.EINVAL + } + if maxRead < fuseMinMaxRead { + maxRead = fuseMinMaxRead + } + fsopts.maxRead = uint32(maxRead) + } else { + fsopts.maxRead = math.MaxUint32 + } + // Check for unparsed options. if len(mopts) != 0 { - log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts) + log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts) return nil, nil, syserror.EINVAL } // Create a new FUSE filesystem. - fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) + fs, err := newFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd) if err != nil { log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err) return nil, nil, err @@ -165,26 +197,28 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // root is the fusefs root directory. - root := fs.newInode(creds, fsopts.rootMode) + root := fs.newRootInode(creds, fsopts.rootMode) return fs.VFSFilesystem(), root.VFSDentry(), nil } -// NewFUSEFilesystem creates a new FUSE filesystem. -func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { - fs := &filesystem{ - devMinor: devMinor, - opts: opts, - } - - conn, err := newFUSEConnection(ctx, device, opts.maxActiveRequests) +// newFUSEFilesystem creates a new FUSE filesystem. +func newFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) { + conn, err := newFUSEConnection(ctx, device, opts) if err != nil { log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err) return nil, syserror.EINVAL } - fs.conn = conn fuseFD := device.Impl().(*DeviceFD) + + fs := &filesystem{ + devMinor: devMinor, + opts: opts, + conn: conn, + } + + fs.VFSFilesystem().IncRef() fuseFD.fs = fs return fs, nil @@ -192,39 +226,375 @@ func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOpt // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { + fs.conn.fd.mu.Lock() + + fs.umounted = true + fs.conn.Abort(ctx) + // Notify all the waiters on this fd. + fs.conn.fd.waitQueue.Notify(waiter.EventIn) + + fs.conn.fd.mu.Unlock() + fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) fs.Filesystem.Release(ctx) } // inode implements kernfs.Inode. type inode struct { + inodeRefs kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren kernfs.InodeNoDynamicLookup kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren kernfs.OrderedChildren + dentry kernfs.Dentry + + // the owning filesystem. fs is immutable. + fs *filesystem + + // metaDataMu protects the metadata of this inode. + metadataMu sync.Mutex + + nodeID uint64 + locks vfs.FileLocks - dentry kernfs.Dentry + // size of the file. + size uint64 + + // attributeVersion is the version of inode's attributes. + attributeVersion uint64 + + // attributeTime is the remaining vaild time of attributes. + attributeTime uint64 + + // version of the inode. + version uint64 + + // link is result of following a symbolic link. + link string } -func (fs *filesystem) newInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { - i := &inode{} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) +func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { + i := &inode{fs: fs} + i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.dentry.Init(i) + i.nodeID = 1 + + return &i.dentry +} + +func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) *kernfs.Dentry { + i := &inode{fs: fs, nodeID: nodeID} + creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} + i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + atomic.StoreUint64(&i.size, attr.Size) + i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + i.EnableLeakCheck() + i.dentry.Init(i) return &i.dentry } // Open implements kernfs.Inode.Open. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + isDir := i.InodeAttrs.Mode().IsDir() + // return error if specified to open directory but inode is not a directory. + if !isDir && opts.Mode.IsDir() { + return nil, syserror.ENOTDIR + } + if opts.Flags&linux.O_LARGEFILE == 0 && atomic.LoadUint64(&i.size) > linux.MAX_NON_LFS { + return nil, syserror.EOVERFLOW + } + + var fd *fileDescription + var fdImpl vfs.FileDescriptionImpl + if isDir { + directoryFD := &directoryFD{} + fd = &(directoryFD.fileDescription) + fdImpl = directoryFD + } else { + regularFD := ®ularFileFD{} + fd = &(regularFD.fileDescription) + fdImpl = regularFD + } + // FOPEN_KEEP_CACHE is the defualt flag for noOpen. + fd.OpenFlag = linux.FOPEN_KEEP_CACHE + + // Only send open request when FUSE server support open or is opening a directory. + if !i.fs.conn.noOpen || isDir { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.Open: couldn't get kernel task from context") + return nil, syserror.EINVAL + } + + // Build the request. + var opcode linux.FUSEOpcode + if isDir { + opcode = linux.FUSE_OPENDIR + } else { + opcode = linux.FUSE_OPEN + } + + in := linux.FUSEOpenIn{Flags: opts.Flags & ^uint32(linux.O_CREAT|linux.O_EXCL|linux.O_NOCTTY)} + if !i.fs.conn.atomicOTrunc { + in.Flags &= ^uint32(linux.O_TRUNC) + } + + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) + if err != nil { + return nil, err + } + + // Send the request and receive the reply. + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return nil, err + } + if err := res.Error(); err == syserror.ENOSYS && !isDir { + i.fs.conn.noOpen = true + } else if err != nil { + return nil, err + } else { + out := linux.FUSEOpenOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return nil, err + } + + // Process the reply. + fd.OpenFlag = out.OpenFlag + if isDir { + fd.OpenFlag &= ^uint32(linux.FOPEN_DIRECT_IO) + } + + fd.Fh = out.Fh + } + } + + // TODO(gvisor.dev/issue/3234): invalidate mmap after implemented it for FUSE Inode + fd.DirectIO = fd.OpenFlag&linux.FOPEN_DIRECT_IO != 0 + fdOptions := &vfs.FileDescriptionOptions{} + if fd.OpenFlag&linux.FOPEN_NONSEEKABLE != 0 { + fdOptions.DenyPRead = true + fdOptions.DenyPWrite = true + fd.Nonseekable = true + } + + // If we don't send SETATTR before open (which is indicated by atomicOTrunc) + // and O_TRUNC is set, update the inode's version number and clean existing data + // by setting the file size to 0. + if i.fs.conn.atomicOTrunc && opts.Flags&linux.O_TRUNC != 0 { + i.fs.conn.mu.Lock() + i.fs.conn.attributeVersion++ + i.attributeVersion = i.fs.conn.attributeVersion + atomic.StoreUint64(&i.size, 0) + i.fs.conn.mu.Unlock() + i.attributeTime = 0 + } + + if err := fd.vfsfd.Init(fdImpl, opts.Flags, rp.Mount(), vfsd, fdOptions); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// Lookup implements kernfs.Inode.Lookup. +func (i *inode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { + in := linux.FUSELookupIn{Name: name} + return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in) +} + +// IterDirents implements kernfs.Inode.IterDirents. +func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return offset, nil +} + +// Valid implements kernfs.Inode.Valid. +func (*inode) Valid(ctx context.Context) bool { + return true +} + +// NewFile implements kernfs.Inode.NewFile. +func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.NewFile: couldn't get kernel task from context", i.nodeID) + return nil, syserror.EINVAL + } + in := linux.FUSECreateIn{ + CreateMeta: linux.FUSECreateMeta{ + Flags: opts.Flags, + Mode: uint32(opts.Mode) | linux.S_IFREG, + Umask: uint32(kernelTask.FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in) +} + +// NewNode implements kernfs.Inode.NewNode. +func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) { + in := linux.FUSEMknodIn{ + MknodMeta: linux.FUSEMknodMeta{ + Mode: uint32(opts.Mode), + Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor), + Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in) +} + +// NewSymlink implements kernfs.Inode.NewSymlink. +func (i *inode) NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) { + in := linux.FUSESymLinkIn{ + Name: name, + Target: target, + } + return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in) +} + +// Unlink implements kernfs.Inode.Unlink. +func (i *inode) Unlink(ctx context.Context, name string, child *vfs.Dentry) error { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) + return syserror.EINVAL + } + in := linux.FUSEUnlinkIn{Name: name} + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) + if err != nil { + return err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return err + } + // only return error, discard res. + if err := res.Error(); err != nil { + return err + } + return i.dentry.RemoveChildLocked(name, child) +} + +// NewDir implements kernfs.Inode.NewDir. +func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { + in := linux.FUSEMkdirIn{ + MkdirMeta: linux.FUSEMkdirMeta{ + Mode: uint32(opts.Mode), + Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), + }, + Name: name, + } + return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in) +} + +// RmDir implements kernfs.Inode.RmDir. +func (i *inode) RmDir(ctx context.Context, name string, child *vfs.Dentry) error { + fusefs := i.fs + task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) + + in := linux.FUSERmDirIn{Name: name} + req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) if err != nil { + return err + } + + res, err := i.fs.conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + + // TODO(Before merging): When creating new nodes, should we add nodes to the ordered children? + // If so we'll probably need to call this. We will also need to add them with the writable flag when + // appropriate. + // return i.OrderedChildren.RmDir(ctx, name, child) + + return nil +} + +// newEntry calls FUSE server for entry creation and allocates corresponding entry according to response. +// Shared by FUSE_MKNOD, FUSE_MKDIR, FUSE_SYMLINK, FUSE_LINK and FUSE_LOOKUP. +func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMode, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*vfs.Dentry, error) { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) + return nil, syserror.EINVAL + } + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) + if err != nil { + return nil, err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return nil, err + } + if err := res.Error(); err != nil { + return nil, err + } + out := linux.FUSEEntryOut{} + if err := res.UnmarshalPayload(&out); err != nil { return nil, err } - return fd.VFSFileDescription(), nil + if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { + return nil, syserror.EIO + } + child := i.fs.newInode(out.NodeID, out.Attr) + if opcode == linux.FUSE_LOOKUP { + i.dentry.InsertChildLocked(name, child) + } else { + i.dentry.InsertChild(name, child) + } + return child.VFSDentry(), nil +} + +// Getlink implements kernfs.Inode.Getlink. +func (i *inode) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + path, err := i.Readlink(ctx, mnt) + return vfs.VirtualDentry{}, path, err +} + +// Readlink implements kernfs.Inode.Readlink. +func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { + if i.Mode().FileType()&linux.S_IFLNK == 0 { + return "", syserror.EINVAL + } + if len(i.link) == 0 { + kernelTask := kernel.TaskFromContext(ctx) + if kernelTask == nil { + log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context") + return "", syserror.EINVAL + } + req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) + if err != nil { + return "", err + } + res, err := i.fs.conn.Call(kernelTask, req) + if err != nil { + return "", err + } + i.link = string(res.data[res.hdr.SizeBytes():]) + if !mnt.Options().ReadOnly { + i.attributeTime = 0 + } + } + return i.link, nil +} + +// getFUSEAttr returns a linux.FUSEAttr of this inode stored in local cache. +// TODO(gvisor.dev/issue/3679): Add support for other fields. +func (i *inode) getFUSEAttr() linux.FUSEAttr { + return linux.FUSEAttr{ + Ino: i.Ino(), + Size: atomic.LoadUint64(&i.size), + Mode: uint32(i.Mode()), + } } // statFromFUSEAttr makes attributes from linux.FUSEAttr to linux.Statx. The @@ -280,45 +650,179 @@ func statFromFUSEAttr(attr linux.FUSEAttr, mask, devMinor uint32) linux.Statx { return stat } -// Stat implements kernfs.Inode.Stat. -func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { - fusefs := fs.Impl().(*filesystem) - conn := fusefs.conn - task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) +// getAttr gets the attribute of this inode by issuing a FUSE_GETATTR request +// or read from local cache. It updates the corresponding attributes if +// necessary. +func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions, flags uint32, fh uint64) (linux.FUSEAttr, error) { + attributeVersion := atomic.LoadUint64(&i.fs.conn.attributeVersion) + + // TODO(gvisor.dev/issue/3679): send the request only if + // - invalid local cache for fields specified in the opts.Mask + // - forced update + // - i.attributeTime expired + // If local cache is still valid, return local cache. + // Currently we always send a request, + // and we always set the metadata with the new result, + // unless attributeVersion has changed. + + task := kernel.TaskFromContext(ctx) if task == nil { log.Warningf("couldn't get kernel task from context") - return linux.Statx{}, syserror.EINVAL + return linux.FUSEAttr{}, syserror.EINVAL } - var in linux.FUSEGetAttrIn - // We don't set any attribute in the request, because in VFS2 fstat(2) will - // finally be translated into vfs.FilesystemImpl.StatAt() (see - // pkg/sentry/syscalls/linux/vfs2/stat.go), resulting in the same flow - // as stat(2). Thus GetAttrFlags and Fh variable will never be used in VFS2. - req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.Ino(), linux.FUSE_GETATTR, &in) + creds := auth.CredentialsFromContext(ctx) + + in := linux.FUSEGetAttrIn{ + GetAttrFlags: flags, + Fh: fh, + } + req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) if err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err } - res, err := conn.Call(task, req) + res, err := i.fs.conn.Call(task, req) if err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err } if err := res.Error(); err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err } var out linux.FUSEGetAttrOut if err := res.UnmarshalPayload(&out); err != nil { - return linux.Statx{}, err + return linux.FUSEAttr{}, err } - // Set all metadata into kernfs.InodeAttrs. - if err := i.SetStat(ctx, fs, creds, vfs.SetStatOptions{ - Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, fusefs.devMinor), + // Local version is newer, return the local one. + // Skip the update. + if attributeVersion != 0 && atomic.LoadUint64(&i.attributeVersion) > attributeVersion { + return i.getFUSEAttr(), nil + } + + // Set the metadata of kernfs.InodeAttrs. + if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { + return linux.FUSEAttr{}, err + } + + // Set the size if no error (after SetStat() check). + atomic.StoreUint64(&i.size, out.Attr.Size) + + return out.Attr, nil +} + +// reviseAttr attempts to update the attributes for internal purposes +// by calling getAttr with a pre-specified mask. +// Used by read, write, lseek. +func (i *inode) reviseAttr(ctx context.Context, flags uint32, fh uint64) error { + // Never need atime for internal purposes. + _, err := i.getAttr(ctx, i.fs.VFSFilesystem(), vfs.StatOptions{ + Mask: linux.STATX_BASIC_STATS &^ linux.STATX_ATIME, + }, flags, fh) + return err +} + +// Stat implements kernfs.Inode.Stat. +func (i *inode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { + attr, err := i.getAttr(ctx, fs, opts, 0, 0) + if err != nil { return linux.Statx{}, err } - return statFromFUSEAttr(out.Attr, opts.Mask, fusefs.devMinor), nil + return statFromFUSEAttr(attr, opts.Mask, i.fs.devMinor), nil +} + +// DecRef implements kernfs.Inode.DecRef. +func (i *inode) DecRef(context.Context) { + i.inodeRefs.DecRef(i.Destroy) +} + +// StatFS implements kernfs.Inode.StatFS. +func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + // TODO(gvisor.dev/issues/3413): Complete the implementation of statfs. + return vfs.GenericStatFS(linux.FUSE_SUPER_MAGIC), nil +} + +// fattrMaskFromStats converts vfs.SetStatOptions.Stat.Mask to linux stats mask +// aligned with the attribute mask defined in include/linux/fs.h. +func fattrMaskFromStats(mask uint32) uint32 { + var fuseAttrMask uint32 + maskMap := map[uint32]uint32{ + linux.STATX_MODE: linux.FATTR_MODE, + linux.STATX_UID: linux.FATTR_UID, + linux.STATX_GID: linux.FATTR_GID, + linux.STATX_SIZE: linux.FATTR_SIZE, + linux.STATX_ATIME: linux.FATTR_ATIME, + linux.STATX_MTIME: linux.FATTR_MTIME, + linux.STATX_CTIME: linux.FATTR_CTIME, + } + for statxMask, fattrMask := range maskMap { + if mask&statxMask != 0 { + fuseAttrMask |= fattrMask + } + } + return fuseAttrMask +} + +// SetStat implements kernfs.Inode.SetStat. +func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return i.setAttr(ctx, fs, creds, opts, false, 0) +} + +func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions, useFh bool, fh uint64) error { + conn := i.fs.conn + task := kernel.TaskFromContext(ctx) + if task == nil { + log.Warningf("couldn't get kernel task from context") + return syserror.EINVAL + } + + // We should retain the original file type when assigning new mode. + fileType := uint16(i.Mode()) & linux.S_IFMT + fattrMask := fattrMaskFromStats(opts.Stat.Mask) + if useFh { + fattrMask |= linux.FATTR_FH + } + in := linux.FUSESetAttrIn{ + Valid: fattrMask, + Fh: fh, + Size: opts.Stat.Size, + Atime: uint64(opts.Stat.Atime.Sec), + Mtime: uint64(opts.Stat.Mtime.Sec), + Ctime: uint64(opts.Stat.Ctime.Sec), + AtimeNsec: opts.Stat.Atime.Nsec, + MtimeNsec: opts.Stat.Mtime.Nsec, + CtimeNsec: opts.Stat.Ctime.Nsec, + Mode: uint32(fileType | opts.Stat.Mode), + UID: opts.Stat.UID, + GID: opts.Stat.GID, + } + req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) + if err != nil { + return err + } + + res, err := conn.Call(task, req) + if err != nil { + return err + } + if err := res.Error(); err != nil { + return err + } + out := linux.FUSEGetAttrOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return err + } + + // Set the metadata of kernfs.InodeAttrs. + if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), + }); err != nil { + return err + } + + return nil } diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go new file mode 100644 index 000000000..625d1547f --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -0,0 +1,242 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// ReadInPages sends FUSE_READ requests for the size after round it up to +// a multiple of page size, blocks on it for reply, processes the reply +// and returns the payload (or joined payloads) as a byte slice. +// This is used for the general purpose reading. +// We do not support direct IO (which read the exact number of bytes) +// at this moment. +func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off uint64, size uint32) ([][]byte, uint32, error) { + attributeVersion := atomic.LoadUint64(&fs.conn.attributeVersion) + + t := kernel.TaskFromContext(ctx) + if t == nil { + log.Warningf("fusefs.Read: couldn't get kernel task from context") + return nil, 0, syserror.EINVAL + } + + // Round up to a multiple of page size. + readSize, _ := usermem.PageRoundUp(uint64(size)) + + // One request cannnot exceed either maxRead or maxPages. + maxPages := fs.conn.maxRead >> usermem.PageShift + if maxPages > uint32(fs.conn.maxPages) { + maxPages = uint32(fs.conn.maxPages) + } + + var outs [][]byte + var sizeRead uint32 + + // readSize is a multiple of usermem.PageSize. + // Always request bytes as a multiple of pages. + pagesRead, pagesToRead := uint32(0), uint32(readSize>>usermem.PageShift) + + // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation. + in := linux.FUSEReadIn{ + Fh: fd.Fh, + LockOwner: 0, // TODO(gvisor.dev/issue/3245): file lock + ReadFlags: 0, // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER + Flags: fd.statusFlags(), + } + + // This loop is intended for fragmented read where the bytes to read is + // larger than either the maxPages or maxRead. + // For the majority of reads with normal size, this loop should only + // execute once. + for pagesRead < pagesToRead { + pagesCanRead := pagesToRead - pagesRead + if pagesCanRead > maxPages { + pagesCanRead = maxPages + } + + in.Offset = off + (uint64(pagesRead) << usermem.PageShift) + in.Size = pagesCanRead << usermem.PageShift + + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) + if err != nil { + return nil, 0, err + } + + // TODO(gvisor.dev/issue/3247): support async read. + + res, err := fs.conn.Call(t, req) + if err != nil { + return nil, 0, err + } + if err := res.Error(); err != nil { + return nil, 0, err + } + + // Not enough bytes in response, + // either we reached EOF, + // or the FUSE server sends back a response + // that cannot even fit the hdr. + if len(res.data) <= res.hdr.SizeBytes() { + // We treat both case as EOF here for now + // since there is no reliable way to detect + // the over-short hdr case. + break + } + + // Directly using the slice to avoid extra copy. + out := res.data[res.hdr.SizeBytes():] + + outs = append(outs, out) + sizeRead += uint32(len(out)) + + pagesRead += pagesCanRead + } + + defer fs.ReadCallback(ctx, fd, off, size, sizeRead, attributeVersion) + + // No bytes returned: offset >= EOF. + if len(outs) == 0 { + return nil, 0, io.EOF + } + + return outs, sizeRead, nil +} + +// ReadCallback updates several information after receiving a read response. +// Due to readahead, sizeRead can be larger than size. +func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off uint64, size uint32, sizeRead uint32, attributeVersion uint64) { + // TODO(gvisor.dev/issue/3247): support async read. + // If this is called by an async read, correctly process it. + // May need to update the signature. + + i := fd.inode() + // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + + // Reached EOF. + if sizeRead < size { + // TODO(gvisor.dev/issue/3630): If we have writeback cache, then we need to fill this hole. + // Might need to update the buf to be returned from the Read(). + + // Update existing size. + newSize := off + uint64(sizeRead) + fs.conn.mu.Lock() + if attributeVersion == i.attributeVersion && newSize < atomic.LoadUint64(&i.size) { + fs.conn.attributeVersion++ + i.attributeVersion = i.fs.conn.attributeVersion + atomic.StoreUint64(&i.size, newSize) + } + fs.conn.mu.Unlock() + } +} + +// Write sends FUSE_WRITE requests and return the bytes +// written according to the response. +// +// Preconditions: len(data) == size. +func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, size uint32, data []byte) (uint32, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + log.Warningf("fusefs.Read: couldn't get kernel task from context") + return 0, syserror.EINVAL + } + + // One request cannnot exceed either maxWrite or maxPages. + maxWrite := uint32(fs.conn.maxPages) << usermem.PageShift + if maxWrite > fs.conn.maxWrite { + maxWrite = fs.conn.maxWrite + } + + // Reuse the same struct for unmarshalling to avoid unnecessary memory allocation. + in := linux.FUSEWriteIn{ + Fh: fd.Fh, + // TODO(gvisor.dev/issue/3245): file lock + LockOwner: 0, + // TODO(gvisor.dev/issue/3245): |= linux.FUSE_READ_LOCKOWNER + // TODO(gvisor.dev/issue/3237): |= linux.FUSE_WRITE_CACHE (not added yet) + WriteFlags: 0, + Flags: fd.statusFlags(), + } + + var written uint32 + + // This loop is intended for fragmented write where the bytes to write is + // larger than either the maxWrite or maxPages or when bigWrites is false. + // Unless a small value for max_write is explicitly used, this loop + // is expected to execute only once for the majority of the writes. + for written < size { + toWrite := size - written + + // Limit the write size to one page. + // Note that the bigWrites flag is obsolete, + // latest libfuse always sets it on. + if !fs.conn.bigWrites && toWrite > usermem.PageSize { + toWrite = usermem.PageSize + } + + // Limit the write size to maxWrite. + if toWrite > maxWrite { + toWrite = maxWrite + } + + in.Offset = off + uint64(written) + in.Size = toWrite + + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) + if err != nil { + return 0, err + } + + req.payload = data[written : written+toWrite] + + // TODO(gvisor.dev/issue/3247): support async write. + + res, err := fs.conn.Call(t, req) + if err != nil { + return 0, err + } + if err := res.Error(); err != nil { + return 0, err + } + + out := linux.FUSEWriteOut{} + if err := res.UnmarshalPayload(&out); err != nil { + return 0, err + } + + // Write more than requested? EIO. + if out.Size > toWrite { + return 0, syserror.EIO + } + + written += out.Size + + // Break if short write. Not necessarily an error. + if out.Size != toWrite { + break + } + } + + return written, nil +} diff --git a/pkg/sentry/fsimpl/fuse/regular_file.go b/pkg/sentry/fsimpl/fuse/regular_file.go new file mode 100644 index 000000000..5bdd096c3 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/regular_file.go @@ -0,0 +1,230 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "math" + "sync" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type regularFileFD struct { + fileDescription + + // off is the file offset. + off int64 + // offMu protects off. + offMu sync.Mutex +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if offset < 0 { + return 0, syserror.EINVAL + } + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + + size := dst.NumBytes() + if size == 0 { + // Early return if count is 0. + return 0, nil + } else if size > math.MaxUint32 { + // FUSE only supports uint32 for size. + // Overflow. + return 0, syserror.EINVAL + } + + // TODO(gvisor.dev/issue/3678): Add direct IO support. + + inode := fd.inode() + + // Reading beyond EOF, update file size if outdated. + if uint64(offset+size) > atomic.LoadUint64(&inode.size) { + if err := inode.reviseAttr(ctx, linux.FUSE_GETATTR_FH, fd.Fh); err != nil { + return 0, err + } + // If the offset after update is still too large, return error. + if uint64(offset) >= atomic.LoadUint64(&inode.size) { + return 0, io.EOF + } + } + + // Truncate the read with updated file size. + fileSize := atomic.LoadUint64(&inode.size) + if uint64(offset+size) > fileSize { + size = int64(fileSize) - offset + } + + buffers, n, err := inode.fs.ReadInPages(ctx, fd, uint64(offset), uint32(size)) + if err != nil { + return 0, err + } + + // TODO(gvisor.dev/issue/3237): support indirect IO (e.g. caching), + // store the bytes that were read ahead. + + // Update the number of bytes to copy for short read. + if n < uint32(size) { + size = int64(n) + } + + // Copy the bytes read to the dst. + // This loop is intended for fragmented reads. + // For the majority of reads, this loop only execute once. + var copied int64 + for _, buffer := range buffers { + toCopy := int64(len(buffer)) + if copied+toCopy > size { + toCopy = size - copied + } + cp, err := dst.DropFirst64(copied).CopyOut(ctx, buffer[:toCopy]) + if err != nil { + return 0, err + } + if int64(cp) != toCopy { + return 0, syserror.EIO + } + copied += toCopy + } + + return copied, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *regularFileFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.offMu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.offMu.Unlock() + return n, err +} + +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *regularFileFD) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + n, _, err := fd.pwrite(ctx, src, offset, opts) + return n, err +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *regularFileFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + fd.offMu.Lock() + n, off, err := fd.pwrite(ctx, src, fd.off, opts) + fd.off = off + fd.offMu.Unlock() + return n, err +} + +// pwrite returns the number of bytes written, final offset and error. The +// final offset should be ignored by PWrite. +func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (written, finalOff int64, err error) { + if offset < 0 { + return 0, offset, syserror.EINVAL + } + + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, offset, syserror.EOPNOTSUPP + } + + inode := fd.inode() + inode.metadataMu.Lock() + defer inode.metadataMu.Unlock() + + // If the file is opened with O_APPEND, update offset to file size. + // Note: since our Open() implements the interface of kernfs, + // and kernfs currently does not support O_APPEND, this will never + // be true before we switch out from kernfs. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { + // Locking inode.metadataMu is sufficient for reading size + offset = int64(inode.size) + } + + srclen := src.NumBytes() + + if srclen > math.MaxUint32 { + // FUSE only supports uint32 for size. + // Overflow. + return 0, offset, syserror.EINVAL + } + if end := offset + srclen; end < offset { + // Overflow. + return 0, offset, syserror.EINVAL + } + + srclen, err = vfs.CheckLimit(ctx, offset, srclen) + if err != nil { + return 0, offset, err + } + + if srclen == 0 { + // Return before causing any side effects. + return 0, offset, nil + } + + src = src.TakeFirst64(srclen) + + // TODO(gvisor.dev/issue/3237): Add cache support: + // buffer cache. Ideally we write from src to our buffer cache first. + // The slice passed to fs.Write() should be a slice from buffer cache. + data := make([]byte, srclen) + // Reason for making a copy here: connection.Call() blocks on kerneltask, + // which in turn acquires mm.activeMu lock. Functions like CopyInTo() will + // attemp to acquire the mm.activeMu lock as well -> deadlock. + // We must finish reading from the userspace memory before + // t.Block() deactivates it. + cp, err := src.CopyIn(ctx, data) + if err != nil { + return 0, offset, err + } + if int64(cp) != srclen { + return 0, offset, syserror.EIO + } + + n, err := fd.inode().fs.Write(ctx, fd, uint64(offset), uint32(srclen), data) + if err != nil { + return 0, offset, err + } + + if n == 0 { + // We have checked srclen != 0 previously. + // If err == nil, then it's a short write and we return EIO. + return 0, offset, syserror.EIO + } + + written = int64(n) + finalOff = offset + written + + if finalOff > int64(inode.size) { + atomic.StoreUint64(&inode.size, uint64(finalOff)) + atomic.AddUint64(&inode.fs.conn.attributeVersion, 1) + } + + return +} diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go new file mode 100644 index 000000000..7fa00569b --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "fmt" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/usermem" +) + +// fuseInitRes is a variable-length wrapper of linux.FUSEInitOut. The FUSE +// server may implement an older version of FUSE protocol, which contains a +// linux.FUSEInitOut with less attributes. +// +// Dynamically-sized objects cannot be marshalled. +type fuseInitRes struct { + marshal.StubMarshallable + + // initOut contains the response from the FUSE server. + initOut linux.FUSEInitOut + + // initLen is the total length of bytes of the response. + initLen uint32 +} + +// UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes. +func (r *fuseInitRes) UnmarshalBytes(src []byte) { + out := &r.initOut + + // Introduced before FUSE kernel version 7.13. + out.Major = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.Minor = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.MaxReadahead = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.Flags = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + out.MaxBackground = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + out.CongestionThreshold = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + out.MaxWrite = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + + // Introduced in FUSE kernel version 7.23. + if len(src) >= 4 { + out.TimeGran = uint32(usermem.ByteOrder.Uint32(src[:4])) + src = src[4:] + } + // Introduced in FUSE kernel version 7.28. + if len(src) >= 2 { + out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) + src = src[2:] + } +} + +// SizeBytes is the size of the payload of the FUSE_INIT response. +func (r *fuseInitRes) SizeBytes() int { + return int(r.initLen) +} + +// Ordinary requests have even IDs, while interrupts IDs are odd. +// Used to increment the unique ID for each FUSE request. +var reqIDStep uint64 = 2 + +// Request represents a FUSE operation request that hasn't been sent to the +// server yet. +// +// +stateify savable +type Request struct { + requestEntry + + id linux.FUSEOpID + hdr *linux.FUSEHeaderIn + data []byte + + // payload for this request: extra bytes to write after + // the data slice. Used by FUSE_WRITE. + payload []byte + + // If this request is async. + async bool + // If we don't care its response. + // Manually set by the caller. + noReply bool +} + +// NewRequest creates a new request that can be sent to the FUSE server. +func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { + conn.fd.mu.Lock() + defer conn.fd.mu.Unlock() + conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) + + hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes() + hdr := linux.FUSEHeaderIn{ + Len: uint32(hdrLen + payload.SizeBytes()), + Opcode: opcode, + Unique: conn.fd.nextOpID, + NodeID: ino, + UID: uint32(creds.EffectiveKUID), + GID: uint32(creds.EffectiveKGID), + PID: pid, + } + + buf := make([]byte, hdr.Len) + + // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again. + hdr.MarshalBytes(buf[:hdrLen]) + payload.MarshalBytes(buf[hdrLen:]) + + return &Request{ + id: hdr.Unique, + hdr: &hdr, + data: buf, + }, nil +} + +// futureResponse represents an in-flight request, that may or may not have +// completed yet. Convert it to a resolved Response by calling Resolve, but note +// that this may block. +// +// +stateify savable +type futureResponse struct { + opcode linux.FUSEOpcode + ch chan struct{} + hdr *linux.FUSEHeaderOut + data []byte + + // If this request is async. + async bool +} + +// newFutureResponse creates a future response to a FUSE request. +func newFutureResponse(req *Request) *futureResponse { + return &futureResponse{ + opcode: req.hdr.Opcode, + ch: make(chan struct{}), + async: req.async, + } +} + +// resolve blocks the task until the server responds to its corresponding request, +// then returns a resolved response. +func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) { + // Return directly for async requests. + if f.async { + return nil, nil + } + + if err := t.Block(f.ch); err != nil { + return nil, err + } + + return f.getResponse(), nil +} + +// getResponse creates a Response from the data the futureResponse has. +func (f *futureResponse) getResponse() *Response { + return &Response{ + opcode: f.opcode, + hdr: *f.hdr, + data: f.data, + } +} + +// Response represents an actual response from the server, including the +// response payload. +// +// +stateify savable +type Response struct { + opcode linux.FUSEOpcode + hdr linux.FUSEHeaderOut + data []byte +} + +// Error returns the error of the FUSE call. +func (r *Response) Error() error { + errno := r.hdr.Error + if errno >= 0 { + return nil + } + + sysErrNo := syscall.Errno(-errno) + return error(sysErrNo) +} + +// DataLen returns the size of the response without the header. +func (r *Response) DataLen() uint32 { + return r.hdr.Len - uint32(r.hdr.SizeBytes()) +} + +// UnmarshalPayload unmarshals the response data into m. +func (r *Response) UnmarshalPayload(m marshal.Marshallable) error { + hdrLen := r.hdr.SizeBytes() + haveDataLen := r.hdr.Len - uint32(hdrLen) + wantDataLen := uint32(m.SizeBytes()) + + if haveDataLen < wantDataLen { + return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen) + } + + // The response data is empty unless there is some payload. And so, doesn't + // need to be unmarshalled. + if r.data == nil { + return nil + } + + // TODO(gVisor.dev/issue/3698): Use the unsafe version once go_marshal is safe to use again. + m.UnmarshalBytes(r.data[hdrLen:]) + return nil +} diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go new file mode 100644 index 000000000..e1d9e3365 --- /dev/null +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -0,0 +1,132 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fuse + +import ( + "io" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/usermem" +) + +func setup(t *testing.T) *testutil.System { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("Error creating kernel: %v", err) + } + + ctx := k.SupervisorContext() + creds := auth.CredentialsFromContext(ctx) + + k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserList: true, + AllowUserMount: true, + }) + + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) + if err != nil { + t.Fatalf("NewMountNamespace(): %v", err) + } + + return testutil.NewSystem(ctx, t, k.VFS(), mntns) +} + +// newTestConnection creates a fuse connection that the sentry can communicate with +// and the FD for the server to communicate with. +func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*connection, *vfs.FileDescription, error) { + vfsObj := &vfs.VirtualFilesystem{} + fuseDev := &DeviceFD{} + + if err := vfsObj.Init(system.Ctx); err != nil { + return nil, nil, err + } + + vd := vfsObj.NewAnonVirtualDentry("genCountFD") + defer vd.DecRef(system.Ctx) + if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, nil, err + } + + fsopts := filesystemOptions{ + maxActiveRequests: maxActiveRequests, + } + fs, err := newFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd) + if err != nil { + return nil, nil, err + } + + return fs.conn, &fuseDev.vfsfd, nil +} + +type testPayload struct { + marshal.StubMarshallable + data uint32 +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (t *testPayload) SizeBytes() int { + return 4 +} + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (t *testPayload) MarshalBytes(dst []byte) { + usermem.ByteOrder.PutUint32(dst[:4], t.data) +} + +// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. +func (t *testPayload) UnmarshalBytes(src []byte) { + *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])} +} + +// Packed implements marshal.Marshallable.Packed. +func (t *testPayload) Packed() bool { + return true +} + +// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. +func (t *testPayload) MarshalUnsafe(dst []byte) { + t.MarshalBytes(dst) +} + +// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. +func (t *testPayload) UnmarshalUnsafe(src []byte) { + t.UnmarshalBytes(src) +} + +// CopyOutN implements marshal.Marshallable.CopyOutN. +func (t *testPayload) CopyOutN(task marshal.CopyContext, addr usermem.Addr, limit int) (int, error) { + panic("not implemented") +} + +// CopyOut implements marshal.Marshallable.CopyOut. +func (t *testPayload) CopyOut(task marshal.CopyContext, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// CopyIn implements marshal.Marshallable.CopyIn. +func (t *testPayload) CopyIn(task marshal.CopyContext, addr usermem.Addr) (int, error) { + panic("not implemented") +} + +// WriteTo implements io.WriterTo.WriteTo. +func (t *testPayload) WriteTo(w io.Writer) (int64, error) { + panic("not implemented") +} diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 2a8011eb4..91d2ae199 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -34,8 +34,11 @@ func (d *dentry) isDir() bool { return d.fileType() == linux.S_IFDIR } -// Preconditions: filesystem.renameMu must be locked. d.dirMu must be locked. -// d.isDir(). child must be a newly-created dentry that has never had a parent. +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +// * child must be a newly-created dentry that has never had a parent. func (d *dentry) cacheNewChildLocked(child *dentry, name string) { d.IncRef() // reference held by child on its parent child.parent = d @@ -46,7 +49,9 @@ func (d *dentry) cacheNewChildLocked(child *dentry, name string) { d.children[name] = child } -// Preconditions: d.dirMu must be locked. d.isDir(). +// Preconditions: +// * d.dirMu must be locked. +// * d.isDir(). func (d *dentry) cacheNegativeLookupLocked(name string) { // Don't cache negative lookups if InteropModeShared is in effect (since // this makes remote lookup unavoidable), or if d.isSynthetic() (in which @@ -79,10 +84,12 @@ type createSyntheticOpts struct { // createSyntheticChildLocked creates a synthetic file with the given name // in d. // -// Preconditions: d.dirMu must be locked. d.isDir(). d does not already contain -// a child with the given name. +// Preconditions: +// * d.dirMu must be locked. +// * d.isDir(). +// * d does not already contain a child with the given name. func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { - d2 := &dentry{ + child := &dentry{ refs: 1, // held by d fs: d.fs, ino: d.fs.nextSyntheticIno(), @@ -97,16 +104,16 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { case linux.S_IFDIR: // Nothing else needs to be done. case linux.S_IFSOCK: - d2.endpoint = opts.endpoint + child.endpoint = opts.endpoint case linux.S_IFIFO: - d2.pipe = opts.pipe + child.pipe = opts.pipe default: panic(fmt.Sprintf("failed to create synthetic file of unrecognized type: %v", opts.mode.FileType())) } - d2.pf.dentry = d2 - d2.vfsd.Init(d2) + child.pf.dentry = child + child.vfsd.Init(child) - d.cacheNewChildLocked(d2, opts.name) + d.cacheNewChildLocked(child, opts.name) d.syntheticChildren++ } @@ -151,7 +158,9 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba return nil } -// Preconditions: d.isDir(). There exists at least one directoryFD representing d. +// Preconditions: +// * d.isDir(). +// * There exists at least one directoryFD representing d. func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { // NOTE(b/135560623): 9P2000.L's readdir does not specify behavior in the // presence of concurrent mutation of an iterated directory, so diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index a3903db33..97b9165cc 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -115,9 +115,12 @@ func putDentrySlice(ds *[]*dentry) { // Dentries which may become cached as a result of the traversal are appended // to *ds. // -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). If !d.cachedMetadataAuthoritative(), then d's cached metadata -// must be up to date. +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !rp.Done(). +// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up +// to date. // // Postconditions: The returned dentry's cached metadata is up to date. func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { @@ -185,8 +188,11 @@ afterSymlink: // getChildLocked returns a dentry representing the child of parent with the // given name. If no such child exists, getChildLocked returns (nil, nil). // -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. -// parent.isDir(). name is not "." or "..". +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. +// * parent.isDir(). +// * name is not "." or "..". // // Postconditions: If getChildLocked returns a non-nil dentry, its cached // metadata is up to date. @@ -206,7 +212,8 @@ func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFil return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds) } -// Preconditions: As for getChildLocked. !parent.isSynthetic(). +// Preconditions: Same as getChildLocked, plus: +// * !parent.isSynthetic(). func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) { if child != nil { // Need to lock child.metadataMu because we might be updating child @@ -279,9 +286,11 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // rp.Start().Impl().(*dentry)). It does not check that the returned directory // is searchable by the provider of rp. // -// Preconditions: fs.renameMu must be locked. !rp.Done(). If -// !d.cachedMetadataAuthoritative(), then d's cached metadata must be up to -// date. +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). +// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up +// to date. func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() @@ -328,9 +337,10 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, // createInRemoteDir (if the parent directory is a real remote directory) or // createInSyntheticDir (if the parent directory is synthetic) to do so. // -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). -func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string) error, createInSyntheticDir func(parent *dentry, name string) error) error { +// Preconditions: +// * !rp.Done(). +// * For the final path component in rp, !rp.ShouldFollowSymlink(). +func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, createInRemoteDir func(parent *dentry, name string, ds **[]*dentry) error, createInSyntheticDir func(parent *dentry, name string) error) error { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) @@ -399,7 +409,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir // RPC will fail with EEXIST like we would have. If the RPC succeeds, and a // stale dentry exists, the dentry will fail revalidation next time it's // used. - if err := createInRemoteDir(parent, name); err != nil { + if err := createInRemoteDir(parent, name, &ds); err != nil { return err } ev := linux.IN_CREATE @@ -414,7 +424,7 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir } // No cached dentry exists; however, there might still be an existing file // at name. As above, we attempt the file creation RPC anyway. - if err := createInRemoteDir(parent, name); err != nil { + if err := createInRemoteDir(parent, name, &ds); err != nil { return err } if child, ok := parent.children[name]; ok && child == nil { @@ -721,7 +731,7 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa // LinkAt implements vfs.FilesystemImpl.LinkAt. func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs.VirtualDentry) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, childName string, _ **[]*dentry) error { if rp.Mount() != vd.Mount() { return syserror.EXDEV } @@ -754,7 +764,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. // MkdirAt implements vfs.FilesystemImpl.MkdirAt. func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error { creds := rp.Credentials() - return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string) error { + return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { if _, err := parent.file.mkdir(ctx, name, (p9.FileMode)(opts.Mode), (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)); err != nil { if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err @@ -789,34 +799,49 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // MknodAt implements vfs.FilesystemImpl.MknodAt. func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MknodOptions) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, ds **[]*dentry) error { creds := rp.Credentials() _, err := parent.file.mknod(ctx, name, (p9.FileMode)(opts.Mode), opts.DevMajor, opts.DevMinor, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) - // If the gofer does not allow creating a socket or pipe, create a - // synthetic one, i.e. one that is kept entirely in memory. - if err == syserror.EPERM { - switch opts.Mode.FileType() { - case linux.S_IFSOCK: - parent.createSyntheticChildLocked(&createSyntheticOpts{ - name: name, - mode: opts.Mode, - kuid: creds.EffectiveKUID, - kgid: creds.EffectiveKGID, - endpoint: opts.Endpoint, - }) - return nil - case linux.S_IFIFO: - parent.createSyntheticChildLocked(&createSyntheticOpts{ - name: name, - mode: opts.Mode, - kuid: creds.EffectiveKUID, - kgid: creds.EffectiveKGID, - pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), - }) - return nil - } + if err != syserror.EPERM { + return err } - return err + + // EPERM means that gofer does not allow creating a socket or pipe. Fallback + // to creating a synthetic one, i.e. one that is kept entirely in memory. + + // Check that we're not overriding an existing file with a synthetic one. + _, err = fs.stepLocked(ctx, rp, parent, true, ds) + switch { + case err == nil: + // Step succeeded, another file exists. + return syserror.EEXIST + case err != syserror.ENOENT: + // Unexpected error. + return err + } + + switch opts.Mode.FileType() { + case linux.S_IFSOCK: + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + endpoint: opts.Endpoint, + }) + return nil + case linux.S_IFIFO: + parent.createSyntheticChildLocked(&createSyntheticOpts{ + name: name, + mode: opts.Mode, + kuid: creds.EffectiveKUID, + kgid: creds.EffectiveKGID, + pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + }) + return nil + } + // Retain error from gofer if synthetic file cannot be created internally. + return syserror.EPERM }, nil) } @@ -834,7 +859,14 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + unlocked := false + unlock := func() { + if !unlocked { + fs.renameMuRUnlockAndCheckCaching(ctx, &ds) + unlocked = true + } + } + defer unlock() start := rp.Start().Impl().(*dentry) if !start.cachedMetadataAuthoritative() { @@ -851,7 +883,10 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } - return start.openLocked(ctx, rp, &opts) + start.IncRef() + defer start.DecRef(ctx) + unlock() + return start.open(ctx, rp, &opts) } afterTrailingSymlink: @@ -901,11 +936,15 @@ afterTrailingSymlink: if rp.MustBeDir() && !child.isDir() { return nil, syserror.ENOTDIR } - return child.openLocked(ctx, rp, &opts) + child.IncRef() + defer child.DecRef(ctx) + unlock() + return child.open(ctx, rp, &opts) } -// Preconditions: fs.renameMu must be locked. -func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +// Preconditions: The caller must hold no locks (since opening pipes may block +// indefinitely). +func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if err := d.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err @@ -968,7 +1007,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return nil, syserror.ENXIO } if d.fs.iopts.OpenSocketsByConnecting { - return d.connectSocketLocked(ctx, opts) + return d.openSocketByConnecting(ctx, opts) } case linux.S_IFIFO: if d.isSynthetic() { @@ -977,7 +1016,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf } if vfd == nil { - if vfd, err = d.openSpecialFileLocked(ctx, mnt, opts); err != nil { + if vfd, err = d.openSpecialFile(ctx, mnt, opts); err != nil { return nil, err } } @@ -987,7 +1026,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // step is required even if !d.cachedMetadataAuthoritative() because // d.mappings has to be updated. // d.metadataMu has already been acquired if trunc == true. - d.updateFileSizeLocked(0) + d.updateSizeLocked(0) if d.cachedMetadataAuthoritative() { d.touchCMtimeLocked() @@ -996,7 +1035,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return vfd, err } -func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +func (d *dentry) openSocketByConnecting(ctx context.Context, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { if opts.Flags&linux.O_DIRECT != 0 { return nil, syserror.EINVAL } @@ -1016,7 +1055,7 @@ func (d *dentry) connectSocketLocked(ctx context.Context, opts *vfs.OpenOptions) return fd, nil } -func (d *dentry) openSpecialFileLocked(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +func (d *dentry) openSpecialFile(ctx context.Context, mnt *vfs.Mount, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if opts.Flags&linux.O_DIRECT != 0 { return nil, syserror.EINVAL @@ -1058,8 +1097,10 @@ retry: return &fd.vfsfd, nil } -// Preconditions: d.fs.renameMu must be locked. d.dirMu must be locked. -// !d.isSynthetic(). +// Preconditions: +// * d.fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !d.isSynthetic(). func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { if err := d.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil { return nil, err @@ -1270,6 +1311,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if !renamed.isDir() { return syserror.EISDIR } + if genericIsAncestorDentry(replaced, renamed) { + return syserror.ENOTEMPTY + } } else { if rp.MustBeDir() || renamed.isDir() { return syserror.ENOTDIR @@ -1320,14 +1364,15 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // with reference counts and queue oldParent for checkCachingLocked if the // parent isn't actually changing. if oldParent != newParent { + oldParent.decRefLocked() ds = appendDentry(ds, oldParent) newParent.IncRef() if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ } + renamed.parent = newParent } - renamed.parent = newParent renamed.name = newName if newParent.children == nil { newParent.children = make(map[string]*dentry) @@ -1438,7 +1483,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, target string) error { - return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string) error { + return fs.doCreateAt(ctx, rp, false /* dir */, func(parent *dentry, name string, _ **[]*dentry) error { creds := rp.Credentials() _, err := parent.file.symlink(ctx, target, name, (p9.UID)(creds.EffectiveKUID), (p9.GID)(creds.EffectiveKGID)) return err @@ -1450,7 +1495,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return fs.unlinkAt(ctx, rp, false /* dir */) } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() @@ -1471,13 +1516,15 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath path: opts.Addr, }, nil } - return d.endpoint, nil + if d.endpoint != nil { + return d.endpoint, nil + } } return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) @@ -1485,11 +1532,11 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si if err != nil { return nil, err } - return d.listxattr(ctx, rp.Credentials(), size) + return d.listXattr(ctx, rp.Credentials(), size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds) @@ -1497,11 +1544,11 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt if err != nil { return "", err } - return d.getxattr(ctx, rp.Credentials(), &opts) + return d.getXattr(ctx, rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) @@ -1509,7 +1556,7 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.setxattr(ctx, rp.Credentials(), &opts); err != nil { + if err := d.setXattr(ctx, rp.Credentials(), &opts); err != nil { fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } @@ -1519,8 +1566,8 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return nil } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() d, err := fs.resolveLocked(ctx, rp, &ds) @@ -1528,7 +1575,7 @@ func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } - if err := d.removexattr(ctx, rp.Credentials(), name); err != nil { + if err := d.removeXattr(ctx, rp.Credentials(), name); err != nil { fs.renameMuRUnlockAndCheckCaching(ctx, &ds) return err } diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 63e589859..aaad9c0d9 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -195,11 +195,7 @@ const ( // and consistent with Linux's semantics (in particular, it is not always // possible for clients to set arbitrary atimes and mtimes depending on the // remote filesystem implementation, and never possible for clients to set - // arbitrary ctimes.) If a dentry containing a client-defined atime or - // mtime is evicted from cache, client timestamps will be sent to the - // remote filesystem on a best-effort basis to attempt to ensure that - // timestamps will be preserved when another dentry representing the same - // file is instantiated. + // arbitrary ctimes.) InteropModeExclusive InteropMode = iota // InteropModeWritethrough is appropriate when there are read-only users of @@ -703,6 +699,13 @@ type dentry struct { locks vfs.FileLocks // Inotify watches for this dentry. + // + // Note that inotify may behave unexpectedly in the presence of hard links, + // because dentries corresponding to the same file have separate inotify + // watches when they should share the same set. This is the case because it is + // impossible for us to know for sure whether two dentries correspond to the + // same underlying file (see the gofer filesystem section fo vfs/inotify.md for + // a more in-depth discussion on this matter). watches vfs.Watches } @@ -830,7 +833,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { atomic.StoreUint32(&d.nlink, uint32(attr.NLink)) } if mask.Size { - d.updateFileSizeLocked(attr.Size) + d.updateSizeLocked(attr.Size) } } @@ -984,7 +987,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // d.size should be kept up to date, and privatized // copy-on-write mappings of truncated pages need to be // invalidated, even if InteropModeShared is in effect. - d.updateFileSizeLocked(stat.Size) + d.updateSizeLocked(stat.Size) } } if d.fs.opts.interop == InteropModeShared { @@ -1021,8 +1024,31 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } +// doAllocate performs an allocate operation on d. Note that d.metadataMu will +// be held when allocate is called. +func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate func() error) error { + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + + // Allocating a smaller size is a noop. + size := offset + length + if d.cachedMetadataAuthoritative() && size <= d.size { + return nil + } + + err := allocate() + if err != nil { + return err + } + d.updateSizeLocked(size) + if d.cachedMetadataAuthoritative() { + d.touchCMtimeLocked() + } + return nil +} + // Preconditions: d.metadataMu must be locked. -func (d *dentry) updateFileSizeLocked(newSize uint64) { +func (d *dentry) updateSizeLocked(newSize uint64) { d.dataMu.Lock() oldSize := d.size atomic.StoreUint64(&d.size, newSize) @@ -1060,6 +1086,21 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + // We only support xattrs prefixed with "user." (see b/148380782). Currently, + // there is no need to expose any other xattrs through a gofer. + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return syserror.EOPNOTSUPP + } + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + kuid := auth.KUID(atomic.LoadUint32(&d.uid)) + kgid := auth.KGID(atomic.LoadUint32(&d.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err + } + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) +} + func (d *dentry) mayDelete(creds *auth.Credentials, child *dentry) error { return vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&child.uid))) } @@ -1293,30 +1334,19 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.handleMu.Unlock() if !d.file.isNil() { - if !d.isDeleted() { - // Write dirty timestamps back to the remote filesystem. - atimeDirty := atomic.LoadUint32(&d.atimeDirty) != 0 - mtimeDirty := atomic.LoadUint32(&d.mtimeDirty) != 0 - if atimeDirty || mtimeDirty { - atime := atomic.LoadInt64(&d.atime) - mtime := atomic.LoadInt64(&d.mtime) - if err := d.file.setAttr(ctx, p9.SetAttrMask{ - ATime: atimeDirty, - ATimeNotSystemTime: atimeDirty, - MTime: mtimeDirty, - MTimeNotSystemTime: mtimeDirty, - }, p9.SetAttr{ - ATimeSeconds: uint64(atime / 1e9), - ATimeNanoSeconds: uint64(atime % 1e9), - MTimeSeconds: uint64(mtime / 1e9), - MTimeNanoSeconds: uint64(mtime % 1e9), - }); err != nil { - log.Warningf("gofer.dentry.destroyLocked: failed to write dirty timestamps back: %v", err) - } - } + // Note that it's possible that d.atimeDirty or d.mtimeDirty are true, + // i.e. client and server timestamps may differ (because e.g. a client + // write was serviced by the page cache, and only written back to the + // remote file later). Ideally, we'd write client timestamps back to + // the remote filesystem so that timestamps for a new dentry + // instantiated for the same file would remain coherent. Unfortunately, + // this turns out to be too expensive in many cases, so for now we + // don't do this. + if err := d.file.close(ctx); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to close file: %v", err) } - d.file.close(ctx) d.file = p9file{} + // Remove d from the set of syncable dentries. d.fs.syncMu.Lock() delete(d.fs.syncableDentries, d) @@ -1344,9 +1374,7 @@ func (d *dentry) setDeleted() { atomic.StoreUint32(&d.deleted, 1) } -// We only support xattrs prefixed with "user." (see b/148380782). Currently, -// there is no need to expose any other xattrs through a gofer. -func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { +func (d *dentry) listXattr(ctx context.Context, creds *auth.Credentials, size uint64) ([]string, error) { if d.file.isNil() || !d.userXattrSupported() { return nil, nil } @@ -1356,6 +1384,7 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui } xattrs := make([]string, 0, len(xattrMap)) for x := range xattrMap { + // We only support xattrs in the user.* namespace. if strings.HasPrefix(x, linux.XATTR_USER_PREFIX) { xattrs = append(xattrs, x) } @@ -1363,51 +1392,33 @@ func (d *dentry) listxattr(ctx context.Context, creds *auth.Credentials, size ui return xattrs, nil } -func (d *dentry) getxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { +func (d *dentry) getXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { if d.file.isNil() { return "", syserror.ENODATA } - if err := d.checkPermissions(creds, vfs.MayRead); err != nil { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return "", syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return "", syserror.ENODATA - } return d.file.getXattr(ctx, opts.Name, opts.Size) } -func (d *dentry) setxattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetxattrOptions) error { +func (d *dentry) setXattr(ctx context.Context, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { if d.file.isNil() { return syserror.EPERM } - if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return syserror.EPERM - } return d.file.setXattr(ctx, opts.Name, opts.Value, opts.Flags) } -func (d *dentry) removexattr(ctx context.Context, creds *auth.Credentials, name string) error { +func (d *dentry) removeXattr(ctx context.Context, creds *auth.Credentials, name string) error { if d.file.isNil() { return syserror.EPERM } - if err := d.checkPermissions(creds, vfs.MayWrite); err != nil { + if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !d.userXattrSupported() { - return syserror.EPERM - } return d.file.removeXattr(ctx, name) } @@ -1418,7 +1429,9 @@ func (d *dentry) userXattrSupported() bool { return filetype == linux.ModeRegular || filetype == linux.ModeDirectory } -// Preconditions: !d.isSynthetic(). d.isRegularFile() || d.isDir(). +// Preconditions: +// * !d.isSynthetic(). +// * d.isRegularFile() || d.isDir(). func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool) error { // O_TRUNC unconditionally requires us to obtain a new handle (opened with // O_TRUNC). @@ -1463,8 +1476,9 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool return err } - if d.hostFD < 0 && openReadable && h.fd >= 0 { - // We have no existing FD; use the new FD for at least reading. + if d.hostFD < 0 && h.fd >= 0 && openReadable && (d.writeFile.isNil() || openWritable) { + // We have no existing FD, and the new FD meets the requirements + // for d.hostFD, so start using it. d.hostFD = h.fd } else if d.hostFD >= 0 && d.writeFile.isNil() && openWritable { // We have an existing read-only FD, but the file has just been @@ -1656,30 +1670,30 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return nil } -// Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { - return fd.dentry().listxattr(ctx, auth.CredentialsFromContext(ctx), size) +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.dentry().listXattr(ctx, auth.CredentialsFromContext(ctx), size) } -// Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { - return fd.dentry().getxattr(ctx, auth.CredentialsFromContext(ctx), &opts) +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.dentry().getXattr(ctx, auth.CredentialsFromContext(ctx), &opts) } -// Setxattr implements vfs.FileDescriptionImpl.Setxattr. -func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { d := fd.dentry() - if err := d.setxattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { + if err := d.setXattr(ctx, auth.CredentialsFromContext(ctx), &opts); err != nil { return err } d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) return nil } -// Removexattr implements vfs.FileDescriptionImpl.Removexattr. -func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d := fd.dentry() - if err := d.removexattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { + if err := d.removeXattr(ctx, auth.CredentialsFromContext(ctx), name); err != nil { return err } d.InotifyWithParent(ctx, linux.IN_ATTRIB, 0, vfs.InodeEvent) diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 87f0b877f..21b4a96fe 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -127,6 +127,13 @@ func (f p9file) close(ctx context.Context) error { return err } +func (f p9file) setAttrClose(ctx context.Context, valid p9.SetAttrMask, attr p9.SetAttr) error { + ctx.UninterruptibleSleepStart(false) + err := f.file.SetAttrClose(valid, attr) + ctx.UninterruptibleSleepFinish(false) + return err +} + func (f p9file) open(ctx context.Context, flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { ctx.UninterruptibleSleepStart(false) fdobj, qid, iounit, err := f.file.Open(flags) diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 7e1cbf065..24f03ee94 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -56,10 +56,16 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { if !fd.vfsfd.IsWritable() { return nil } - // Skip flushing if writes may be buffered by the client, since (as with - // the VFS1 client) we don't flush buffered writes on close anyway. + // Skip flushing if there are client-buffered writes, since (as with the + // VFS1 client) we don't flush buffered writes on close anyway. d := fd.dentry() - if d.fs.opts.interop == InteropModeExclusive { + if d.fs.opts.interop != InteropModeExclusive { + return nil + } + d.dataMu.RLock() + haveDirtyPages := !d.dirty.IsEmpty() + d.dataMu.RUnlock() + if haveDirtyPages { return nil } d.handleMu.RLock() @@ -73,28 +79,11 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { // Allocate implements vfs.FileDescriptionImpl.Allocate. func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { d := fd.dentry() - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - - // Allocating a smaller size is a noop. - size := offset + length - if d.cachedMetadataAuthoritative() && size <= d.size { - return nil - } - - d.handleMu.RLock() - err := d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) - d.handleMu.RUnlock() - if err != nil { - return err - } - d.dataMu.Lock() - atomic.StoreUint64(&d.size, size) - d.dataMu.Unlock() - if d.cachedMetadataAuthoritative() { - d.touchCMtimeLocked() - } - return nil + return d.doAllocate(ctx, offset, length, func() error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + return d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) } // PRead implements vfs.FileDescriptionImpl.PRead. @@ -117,6 +106,10 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs return 0, io.EOF } + var ( + n int64 + readErr error + ) if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { // Lock d.metadataMu for the rest of the read to prevent d.size from // changing. @@ -127,20 +120,25 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs if err := d.writeback(ctx, offset, dst.NumBytes()); err != nil { return 0, err } - } - - rw := getDentryReadWriter(ctx, d, offset) - if fd.vfsfd.StatusFlags()&linux.O_DIRECT != 0 { + rw := getDentryReadWriter(ctx, d, offset) // Require the read to go to the remote file. rw.direct = true + n, readErr = dst.CopyOutFrom(ctx, rw) + putDentryReadWriter(rw) + if d.fs.opts.interop != InteropModeShared { + // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). + d.touchAtimeLocked(fd.vfsfd.Mount()) + } + } else { + rw := getDentryReadWriter(ctx, d, offset) + n, readErr = dst.CopyOutFrom(ctx, rw) + putDentryReadWriter(rw) + if d.fs.opts.interop != InteropModeShared { + // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). + d.touchAtime(fd.vfsfd.Mount()) + } } - n, err := dst.CopyOutFrom(ctx, rw) - putDentryReadWriter(rw) - if d.fs.opts.interop != InteropModeShared { - // Compare Linux's mm/filemap.c:do_generic_file_read() => file_accessed(). - d.touchAtime(fd.vfsfd.Mount()) - } - return n, err + return n, readErr } // Read implements vfs.FileDescriptionImpl.Read. diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index a6368fdd0..dc960e5bf 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" @@ -39,8 +40,14 @@ type specialFileFD struct { // handle is used for file I/O. handle is immutable. handle handle + // isRegularFile is true if this FD represents a regular file which is only + // possible when filesystemOptions.regularFilesUseSpecialFileFD is in + // effect. isRegularFile is immutable. + isRegularFile bool + // seekable is true if this file description represents a file for which - // file offset is significant, i.e. a regular file. seekable is immutable. + // file offset is significant, i.e. a regular file, character device or + // block device. seekable is immutable. seekable bool // haveQueue is true if this file description represents a file for which @@ -55,12 +62,13 @@ type specialFileFD struct { func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { ftype := d.fileType() - seekable := ftype == linux.S_IFREG + seekable := ftype == linux.S_IFREG || ftype == linux.S_IFCHR || ftype == linux.S_IFBLK haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0 fd := &specialFileFD{ - handle: h, - seekable: seekable, - haveQueue: haveQueue, + handle: h, + isRegularFile: ftype == linux.S_IFREG, + seekable: seekable, + haveQueue: haveQueue, } fd.LockFD.Init(locks) if haveQueue { @@ -128,6 +136,16 @@ func (fd *specialFileFD) EventUnregister(e *waiter.Entry) { fd.fileDescription.EventUnregister(e) } +func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + if fd.isRegularFile { + d := fd.dentry() + return d.doAllocate(ctx, offset, length, func() error { + return fd.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + }) + } + return fd.FileDescriptionDefaultImpl.Allocate(ctx, mode, offset, length) +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { if fd.seekable && offset < 0 { @@ -200,13 +218,13 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off // If the regular file fd was opened with O_APPEND, make sure the file size // is updated. There is a possible race here if size is modified externally // after metadata cache is updated. - if fd.seekable && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { if err := d.updateFromGetattr(ctx); err != nil { return 0, offset, err } } - if fd.seekable { + if fd.isRegularFile { // We need to hold the metadataMu *while* writing to a regular file. d.metadataMu.Lock() defer d.metadataMu.Unlock() @@ -236,18 +254,20 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } - finalOff = offset + // Update offset if the offset is valid. + if offset >= 0 { + offset += int64(n) + } // Update file size for regular files. - if fd.seekable { - finalOff += int64(n) + if fd.isRegularFile { // d.metadataMu is already locked at this point. - if uint64(finalOff) > d.size { + if uint64(offset) > d.size { d.dataMu.Lock() defer d.dataMu.Unlock() - atomic.StoreUint64(&d.size, uint64(finalOff)) + atomic.StoreUint64(&d.size, uint64(offset)) } } - return int64(n), finalOff, err + return int64(n), offset, err } // Write implements vfs.FileDescriptionImpl.Write. diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 2cb8191b9..7e825caae 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -38,7 +38,7 @@ func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { - if mnt.Flags.NoATime { + if mnt.Flags.NoATime || mnt.ReadOnly() { return } if err := mnt.CheckBeginWrite(); err != nil { @@ -52,8 +52,23 @@ func (d *dentry) touchAtime(mnt *vfs.Mount) { mnt.EndWrite() } -// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has -// successfully called vfs.Mount.CheckBeginWrite(). +// Preconditions: d.metadataMu is locked. d.cachedMetadataAuthoritative() == true. +func (d *dentry) touchAtimeLocked(mnt *vfs.Mount) { + if mnt.Flags.NoATime || mnt.ReadOnly() { + return + } + if err := mnt.CheckBeginWrite(); err != nil { + return + } + now := d.fs.clock.Now().Nanoseconds() + atomic.StoreInt64(&d.atime, now) + atomic.StoreUint32(&d.atimeDirty, 1) + mnt.EndWrite() +} + +// Preconditions: +// * d.cachedMetadataAuthoritative() == true. +// * The caller has successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() @@ -61,8 +76,9 @@ func (d *dentry) touchCtime() { d.metadataMu.Unlock() } -// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has -// successfully called vfs.Mount.CheckBeginWrite(). +// Preconditions: +// * d.cachedMetadataAuthoritative() == true. +// * The caller has successfully called vfs.Mount.CheckBeginWrite(). func (d *dentry) touchCMtime() { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() @@ -72,8 +88,9 @@ func (d *dentry) touchCMtime() { d.metadataMu.Unlock() } -// Preconditions: d.cachedMetadataAuthoritative() == true. The caller has -// locked d.metadataMu. +// Preconditions: +// * d.cachedMetadataAuthoritative() == true. +// * The caller has locked d.metadataMu. func (d *dentry) touchCMtimeLocked() { now := d.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&d.mtime, now) diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index bd701bbc7..56bcf9bdb 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -1,12 +1,37 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "inode_refs", + out = "inode_refs.go", + package = "host", + prefix = "inode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "inode", + }, +) + +go_template_instance( + name = "connected_endpoint_refs", + out = "connected_endpoint_refs.go", + package = "host", + prefix = "ConnectedEndpoint", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "ConnectedEndpoint", + }, +) + go_library( name = "host", srcs = [ + "connected_endpoint_refs.go", "control.go", "host.go", + "inode_refs.go", "ioctl_unsafe.go", "mmap.go", "socket.go", @@ -24,6 +49,7 @@ go_library( "//pkg/fspath", "//pkg/iovec", "//pkg/log", + "//pkg/marshal/primitive", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index bd6caba06..db8536f26 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" @@ -41,6 +40,44 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) { + // Determine if hostFD is seekable. If not, this syscall will return ESPIPE + // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character + // devices. + _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) + seekable := err != syserror.ESPIPE + + i := &inode{ + hostFD: hostFD, + ino: fs.NextIno(), + isTTY: isTTY, + wouldBlock: wouldBlock(uint32(fileType)), + seekable: seekable, + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + canMap: fileType == linux.S_IFREG, + } + i.pf.inode = i + i.refs.EnableLeakCheck() + + // Non-seekable files can't be memory mapped, assert this. + if !i.seekable && i.canMap { + panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") + } + + // If the hostFD would block, we must set it to non-blocking and handle + // blocking behavior in the sentry. + if i.wouldBlock { + if err := syscall.SetNonblock(i.hostFD, true); err != nil { + return nil, err + } + if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { + return nil, err + } + } + return i, nil +} + // NewFDOptions contains options to NewFD. type NewFDOptions struct { // If IsTTY is true, the file descriptor is a TTY. @@ -76,44 +113,11 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) flags = uint32(flagsInt) } - fileMode := linux.FileMode(s.Mode) - fileType := fileMode.FileType() - - // Determine if hostFD is seekable. If not, this syscall will return ESPIPE - // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character - // devices. - _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) - seekable := err != syserror.ESPIPE - - i := &inode{ - hostFD: hostFD, - ino: fs.NextIno(), - isTTY: opts.IsTTY, - wouldBlock: wouldBlock(uint32(fileType)), - seekable: seekable, - // NOTE(b/38213152): Technically, some obscure char devices can be memory - // mapped, but we only allow regular files. - canMap: fileType == linux.S_IFREG, - } - i.pf.inode = i - - // Non-seekable files can't be memory mapped, assert this. - if !i.seekable && i.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - // If the hostFD would block, we must set it to non-blocking and handle - // blocking behavior in the sentry. - if i.wouldBlock { - if err := syscall.SetNonblock(i.hostFD, true); err != nil { - return nil, err - } - if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { - return nil, err - } - } - d := &kernfs.Dentry{} + i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY) + if err != nil { + return nil, err + } d.Init(i) // i.open will take a reference on d. @@ -135,12 +139,12 @@ func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs // filesystemType implements vfs.FilesystemType. type filesystemType struct{} -// GetFilesystem implements FilesystemType.GetFilesystem. +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (filesystemType) GetFilesystem(context.Context, *vfs.VirtualFilesystem, *auth.Credentials, string, vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { panic("host.filesystemType.GetFilesystem should never be called") } -// Name implements FilesystemType.Name. +// Name implements vfs.FilesystemType.Name. func (filesystemType) Name() string { return "none" } @@ -182,13 +186,14 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe // inode implements kernfs.Inode. type inode struct { + kernfs.InodeNoStatFS kernfs.InodeNotDirectory kernfs.InodeNotSymlink locks vfs.FileLocks // When the reference count reaches zero, the host fd is closed. - refs.AtomicRefCount + refs inodeRefs // hostFD contains the host fd that this file was originally created from, // which must be available at time of restore. @@ -238,7 +243,7 @@ type inode struct { pf inodePlatformFile } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { @@ -247,7 +252,7 @@ func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, a return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(s.Mode), auth.KUID(s.Uid), auth.KGID(s.Gid)) } -// Mode implements kernfs.Inode. +// Mode implements kernfs.Inode.Mode. func (i *inode) Mode() linux.FileMode { var s syscall.Stat_t if err := syscall.Fstat(i.hostFD, &s); err != nil { @@ -258,7 +263,7 @@ func (i *inode) Mode() linux.FileMode { return linux.FileMode(s.Mode) } -// Stat implements kernfs.Inode. +// Stat implements kernfs.Inode.Stat. func (i *inode) Stat(ctx context.Context, vfsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { if opts.Mask&linux.STATX__RESERVED != 0 { return linux.Statx{}, syserror.EINVAL @@ -371,7 +376,7 @@ func (i *inode) fstat(fs *filesystem) (linux.Statx, error) { }, nil } -// SetStat implements kernfs.Inode. +// SetStat implements kernfs.Inode.SetStat. func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { s := &opts.Stat @@ -430,22 +435,29 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre return nil } -// DecRef implements kernfs.Inode. -func (i *inode) DecRef(ctx context.Context) { - i.AtomicRefCount.DecRefWithDestructor(ctx, i.Destroy) +// IncRef implements kernfs.Inode.IncRef. +func (i *inode) IncRef() { + i.refs.IncRef() } -// Destroy implements kernfs.Inode. -func (i *inode) Destroy(context.Context) { - if i.wouldBlock { - fdnotifier.RemoveFD(int32(i.hostFD)) - } - if err := unix.Close(i.hostFD); err != nil { - log.Warningf("failed to close host fd %d: %v", i.hostFD, err) - } +// TryIncRef implements kernfs.Inode.TryIncRef. +func (i *inode) TryIncRef() bool { + return i.refs.TryIncRef() +} + +// DecRef implements kernfs.Inode.DecRef. +func (i *inode) DecRef(ctx context.Context) { + i.refs.DecRef(func() { + if i.wouldBlock { + fdnotifier.RemoveFD(int32(i.hostFD)) + } + if err := unix.Close(i.hostFD); err != nil { + log.Warningf("failed to close host fd %d: %v", i.hostFD, err) + } + }) } -// Open implements kernfs.Inode. +// Open implements kernfs.Inode.Open. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { // Once created, we cannot re-open a socket fd through /proc/[pid]/fd/. if i.Mode().FileType() == linux.S_IFSOCK { @@ -484,7 +496,7 @@ func (i *inode) open(ctx context.Context, d *vfs.Dentry, mnt *vfs.Mount, flags u if i.isTTY { fd := &TTYFileDescription{ fileDescription: fileDescription{inode: i}, - termios: linux.DefaultSlaveTermios, + termios: linux.DefaultReplicaTermios, } fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd @@ -530,33 +542,28 @@ type fileDescription struct { offset int64 } -// SetStat implements vfs.FileDescriptionImpl. +// SetStat implements vfs.FileDescriptionImpl.SetStat. func (f *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) return f.inode.SetStat(ctx, f.vfsfd.Mount().Filesystem(), creds, opts) } -// Stat implements vfs.FileDescriptionImpl. +// Stat implements vfs.FileDescriptionImpl.Stat. func (f *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { return f.inode.Stat(ctx, f.vfsfd.Mount().Filesystem(), opts) } -// Release implements vfs.FileDescriptionImpl. +// Release implements vfs.FileDescriptionImpl.Release. func (f *fileDescription) Release(context.Context) { // noop } -// Allocate implements vfs.FileDescriptionImpl. +// Allocate implements vfs.FileDescriptionImpl.Allocate. func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { - if !f.inode.seekable { - return syserror.ESPIPE - } - - // TODO(gvisor.dev/issue/3589): Implement Allocate for non-pipe hostfds. - return syserror.EOPNOTSUPP + return unix.Fallocate(f.inode.hostFD, uint32(mode), int64(offset), int64(length)) } -// PRead implements FileDescriptionImpl. +// PRead implements vfs.FileDescriptionImpl.PRead. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { i := f.inode if !i.seekable { @@ -566,7 +573,7 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off return readFromHostFD(ctx, i.hostFD, dst, offset, opts.Flags) } -// Read implements FileDescriptionImpl. +// Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { i := f.inode if !i.seekable { @@ -603,7 +610,7 @@ func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, off return int64(n), err } -// PWrite implements FileDescriptionImpl. +// PWrite implements vfs.FileDescriptionImpl.PWrite. func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { if !f.inode.seekable { return 0, syserror.ESPIPE @@ -612,7 +619,7 @@ func (f *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, of return f.writeToHostFD(ctx, src, offset, opts.Flags) } -// Write implements FileDescriptionImpl. +// Write implements vfs.FileDescriptionImpl.Write. func (f *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { i := f.inode if !i.seekable { @@ -660,7 +667,7 @@ func (f *fileDescription) writeToHostFD(ctx context.Context, src usermem.IOSeque return int64(n), err } -// Seek implements FileDescriptionImpl. +// Seek implements vfs.FileDescriptionImpl.Seek. // // Note that we do not support seeking on directories, since we do not even // allow directory fds to be imported at all. @@ -725,13 +732,13 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i return f.offset, nil } -// Sync implements FileDescriptionImpl. +// Sync implements vfs.FileDescriptionImpl.Sync. func (f *fileDescription) Sync(context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } -// ConfigureMMap implements FileDescriptionImpl. +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { if !f.inode.canMap { return syserror.ENODEV diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go index 4979dd0a9..131145b85 100644 --- a/pkg/sentry/fsimpl/host/socket.go +++ b/pkg/sentry/fsimpl/host/socket.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/uniqueid" @@ -59,8 +58,7 @@ func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transpor // // +stateify savable type ConnectedEndpoint struct { - // ref keeps track of references to a ConnectedEndpoint. - ref refs.AtomicRefCount + ConnectedEndpointRefs // mu protects fd below. mu sync.RWMutex `state:"nosave"` @@ -132,9 +130,9 @@ func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable return nil, err } - // AtomicRefCounters start off with a single reference. We need two. - e.ref.IncRef() - e.ref.EnableLeakCheck("host.ConnectedEndpoint") + // ConnectedEndpointRefs start off with a single reference. We need two. + e.IncRef() + e.EnableLeakCheck() return &e, nil } @@ -318,7 +316,7 @@ func (c *ConnectedEndpoint) destroyLocked() { // Release implements transport.ConnectedEndpoint.Release and // transport.Receiver.Release. func (c *ConnectedEndpoint) Release(ctx context.Context) { - c.ref.DecRefWithDestructor(ctx, func(context.Context) { + c.DecRef(func() { c.mu.Lock() c.destroyLocked() c.mu.Unlock() @@ -348,7 +346,7 @@ func (e *SCMConnectedEndpoint) Init() error { // Release implements transport.ConnectedEndpoint.Release and // transport.Receiver.Release. func (e *SCMConnectedEndpoint) Release(ctx context.Context) { - e.ref.DecRefWithDestructor(ctx, func(context.Context) { + e.DecRef(func() { e.mu.Lock() if err := syscall.Close(e.fd); err != nil { log.Warningf("Failed to close host fd %d: %v", err) @@ -378,8 +376,8 @@ func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr s return nil, err } - // AtomicRefCounters start off with a single reference. We need two. - e.ref.IncRef() - e.ref.EnableLeakCheck("host.SCMConnectedEndpoint") + // ConnectedEndpointRefs start off with a single reference. We need two. + e.IncRef() + e.EnableLeakCheck() return &e, nil } diff --git a/pkg/sentry/fsimpl/host/socket_unsafe.go b/pkg/sentry/fsimpl/host/socket_unsafe.go index 35ded24bc..c0bf45f08 100644 --- a/pkg/sentry/fsimpl/host/socket_unsafe.go +++ b/pkg/sentry/fsimpl/host/socket_unsafe.go @@ -63,10 +63,10 @@ func fdReadVec(fd int, bufs [][]byte, control []byte, peek bool, maxlen int64) ( controlTrunc = msg.Flags&syscall.MSG_CTRUNC == syscall.MSG_CTRUNC if n > length { - return length, n, msg.Controllen, controlTrunc, err + return length, n, msg.Controllen, controlTrunc, nil } - return n, n, msg.Controllen, controlTrunc, err + return n, n, msg.Controllen, controlTrunc, nil } // fdWriteVec sends from bufs to fd. diff --git a/pkg/sentry/fsimpl/host/tty.go b/pkg/sentry/fsimpl/host/tty.go index d372c60cb..e02b9b8f6 100644 --- a/pkg/sentry/fsimpl/host/tty.go +++ b/pkg/sentry/fsimpl/host/tty.go @@ -17,6 +17,7 @@ package host import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -75,7 +76,7 @@ func (t *TTYFileDescription) Release(ctx context.Context) { t.fileDescription.Release(ctx) } -// PRead implements vfs.FileDescriptionImpl. +// PRead implements vfs.FileDescriptionImpl.PRead. // // Reading from a TTY is only allowed for foreground process groups. Background // process groups will either get EIO or a SIGTTIN. @@ -93,7 +94,7 @@ func (t *TTYFileDescription) PRead(ctx context.Context, dst usermem.IOSequence, return t.fileDescription.PRead(ctx, dst, offset, opts) } -// Read implements vfs.FileDescriptionImpl. +// Read implements vfs.FileDescriptionImpl.Read. // // Reading from a TTY is only allowed for foreground process groups. Background // process groups will either get EIO or a SIGTTIN. @@ -111,7 +112,7 @@ func (t *TTYFileDescription) Read(ctx context.Context, dst usermem.IOSequence, o return t.fileDescription.Read(ctx, dst, opts) } -// PWrite implements vfs.FileDescriptionImpl. +// PWrite implements vfs.FileDescriptionImpl.PWrite. func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { t.mu.Lock() defer t.mu.Unlock() @@ -126,7 +127,7 @@ func (t *TTYFileDescription) PWrite(ctx context.Context, src usermem.IOSequence, return t.fileDescription.PWrite(ctx, src, offset, opts) } -// Write implements vfs.FileDescriptionImpl. +// Write implements vfs.FileDescriptionImpl.Write. func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { t.mu.Lock() defer t.mu.Unlock() @@ -141,8 +142,13 @@ func (t *TTYFileDescription) Write(ctx context.Context, src usermem.IOSequence, return t.fileDescription.Write(ctx, src, opts) } -// Ioctl implements vfs.FileDescriptionImpl. +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + task := kernel.TaskFromContext(ctx) + if task == nil { + return 0, syserror.ENOTTY + } + // Ignore arg[0]. This is the real FD: fd := t.inode.hostFD ioctl := args[1].Uint64() @@ -152,9 +158,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), termios, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = termios.CopyOut(task, args[2].Pointer()) return 0, err case linux.TCSETS, linux.TCSETSW, linux.TCSETSF: @@ -166,9 +170,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch } var termios linux.Termios - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &termios, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := termios.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetTermios(fd, ioctl, &termios) @@ -192,10 +194,8 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch defer t.mu.Unlock() // Map the ProcessGroup into a ProcessGroupID in the task's PID namespace. - pgID := pidns.IDOfProcessGroup(t.fgProcessGroup) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }) + pgID := primitive.Int32(pidns.IDOfProcessGroup(t.fgProcessGroup)) + _, err := pgID.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSPGRP: @@ -203,11 +203,6 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch // Equivalent to tcsetpgrp(fd, *argp). // Set the foreground process group ID of this terminal. - task := kernel.TaskFromContext(ctx) - if task == nil { - return 0, syserror.ENOTTY - } - t.mu.Lock() defer t.mu.Unlock() @@ -226,12 +221,11 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch return 0, syserror.ENOTTY } - var pgID kernel.ProcessGroupID - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &pgID, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + var pgIDP primitive.Int32 + if _, err := pgIDP.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } + pgID := kernel.ProcessGroupID(pgIDP) // pgID must be non-negative. if pgID < 0 { @@ -260,9 +254,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch if err != nil { return 0, err } - _, err = usermem.CopyObjectOut(ctx, io, args[2].Pointer(), winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err = winsize.CopyOut(task, args[2].Pointer()) return 0, err case linux.TIOCSWINSZ: @@ -273,9 +265,7 @@ func (t *TTYFileDescription) Ioctl(ctx context.Context, io usermem.IO, args arch // set the winsize. var winsize linux.Winsize - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &winsize, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := winsize.CopyIn(task, args[2].Pointer()); err != nil { return 0, err } err := ioctlSetWinsize(fd, &winsize) @@ -376,7 +366,7 @@ func (t *TTYFileDescription) checkChange(ctx context.Context, sig linux.Signal) // // Linux ignores the result of kill_pgrp(). _ = pg.SendSignal(kernel.SignalInfoPriv(sig)) - return kernel.ERESTARTSYS + return syserror.ERESTARTSYS } // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 3835557fe..5e91e0536 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -26,9 +26,54 @@ go_template_instance( }, ) +go_template_instance( + name = "dentry_refs", + out = "dentry_refs.go", + package = "kernfs", + prefix = "Dentry", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Dentry", + }, +) + +go_template_instance( + name = "static_directory_refs", + out = "static_directory_refs.go", + package = "kernfs", + prefix = "StaticDirectory", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "StaticDirectory", + }, +) + +go_template_instance( + name = "dir_refs", + out = "dir_refs.go", + package = "kernfs_test", + prefix = "dir", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "dir", + }, +) + +go_template_instance( + name = "readonly_dir_refs", + out = "readonly_dir_refs.go", + package = "kernfs_test", + prefix = "readonlyDir", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "readonlyDir", + }, +) + go_library( name = "kernfs", srcs = [ + "dentry_refs.go", "dynamic_bytes_file.go", "fd_impl_util.go", "filesystem.go", @@ -36,7 +81,9 @@ go_library( "inode_impl_util.go", "kernfs.go", "slot_list.go", + "static_directory_refs.go", "symlink.go", + "synthetic_directory.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -59,11 +106,17 @@ go_library( go_test( name = "kernfs_test", size = "small", - srcs = ["kernfs_test.go"], + srcs = [ + "dir_refs.go", + "kernfs_test.go", + "readonly_dir_refs.go", + ], deps = [ ":kernfs", "//pkg/abi/linux", "//pkg/context", + "//pkg/log", + "//pkg/refs", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index 12adf727a..1ee089620 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -35,6 +35,7 @@ import ( // +stateify savable type DynamicBytesFile struct { InodeAttrs + InodeNoStatFS InodeNoopRefCount InodeNotDirectory InodeNotSymlink diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index fcee6200a..6518ff5cd 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -15,7 +15,7 @@ package kernfs import ( - "math" + "fmt" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -28,9 +28,25 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// SeekEndConfig describes the SEEK_END behaviour for FDs. +type SeekEndConfig int + +// Constants related to SEEK_END behaviour for FDs. +const ( + // Consider the end of the file to be after the final static entry. This is + // the default option. + SeekEndStaticEntries = iota + // Consider the end of the file to be at offset 0. + SeekEndZero +) + +// GenericDirectoryFDOptions contains configuration for a GenericDirectoryFD. +type GenericDirectoryFDOptions struct { + SeekEnd SeekEndConfig +} + // GenericDirectoryFD implements vfs.FileDescriptionImpl for a generic directory -// inode that uses OrderChildren to track child nodes. GenericDirectoryFD is not -// compatible with dynamic directories. +// inode that uses OrderChildren to track child nodes. // // Note that GenericDirectoryFD holds a lock over OrderedChildren while calling // IterDirents callback. The IterDirents callback therefore cannot hash or @@ -45,6 +61,9 @@ type GenericDirectoryFD struct { vfs.DirectoryFileDescriptionDefaultImpl vfs.LockFD + // Immutable. + seekEnd SeekEndConfig + vfsfd vfs.FileDescription children *OrderedChildren @@ -57,9 +76,9 @@ type GenericDirectoryFD struct { // NewGenericDirectoryFD creates a new GenericDirectoryFD and returns its // dentry. -func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) (*GenericDirectoryFD, error) { +func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) (*GenericDirectoryFD, error) { fd := &GenericDirectoryFD{} - if err := fd.Init(children, locks, opts); err != nil { + if err := fd.Init(children, locks, opts, fdOpts); err != nil { return nil, err } if err := fd.vfsfd.Init(fd, opts.Flags, m, d, &vfs.FileDescriptionOptions{}); err != nil { @@ -71,12 +90,13 @@ func NewGenericDirectoryFD(m *vfs.Mount, d *vfs.Dentry, children *OrderedChildre // Init initializes a GenericDirectoryFD. Use it when overriding // GenericDirectoryFD. Caller must call fd.VFSFileDescription.Init() with the // correct implementation. -func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions) error { +func (fd *GenericDirectoryFD) Init(children *OrderedChildren, locks *vfs.FileLocks, opts *vfs.OpenOptions, fdOpts GenericDirectoryFDOptions) error { if vfs.AccessTypesForOpenFlags(opts)&vfs.MayWrite != 0 { // Can't open directories for writing. return syserror.EISDIR } fd.LockFD.Init(locks) + fd.seekEnd = fdOpts.SeekEnd fd.children = children return nil } @@ -209,9 +229,17 @@ func (fd *GenericDirectoryFD) Seek(ctx context.Context, offset int64, whence int case linux.SEEK_CUR: offset += fd.off case linux.SEEK_END: - // TODO(gvisor.dev/issue/1193): This can prevent new files from showing up - // if they are added after SEEK_END. - offset = math.MaxInt64 + switch fd.seekEnd { + case SeekEndStaticEntries: + fd.children.mu.RLock() + offset += int64(len(fd.children.set)) + offset += 2 // '.' and '..' aren't tracked in children. + fd.children.mu.RUnlock() + case SeekEndZero: + // No-op: offset += 0. + default: + panic(fmt.Sprintf("Invalid GenericDirectoryFD.seekEnd = %v", fd.seekEnd)) + } default: return 0, syserror.EINVAL } diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index d7edb6342..89ed265dc 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -32,7 +32,9 @@ import ( // // stepExistingLocked is loosely analogous to fs/namei.c:walk_component(). // -// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * !rp.Done(). // // Postcondition: Caller must call fs.processDeferredDecRefs*. func (fs *Filesystem) stepExistingLocked(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, mayFollowSymlinks bool) (*vfs.Dentry, error) { @@ -107,8 +109,11 @@ afterSymlink: // or vfs.ResolvingPath.ResolveChild(name) returns childVFSD (which may be // nil) to verify that the returned child (or lack thereof) is correct. // -// Preconditions: Filesystem.mu must be locked for at least reading. -// parent.dirMu must be locked. parent.isDir(). name is not "." or "..". +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * parent.dirMu must be locked. +// * parent.isDir(). +// * name is not "." or "..". // // Postconditions: Caller must call fs.processDeferredDecRefs*. func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *Dentry, name string, child *Dentry) (*Dentry, error) { @@ -135,7 +140,7 @@ func (fs *Filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir } // Reference on childVFSD dropped by a corresponding Valid. child = childVFSD.Impl().(*Dentry) - parent.insertChildLocked(name, child) + parent.InsertChildLocked(name, child) } return child, nil } @@ -171,7 +176,9 @@ func (fs *Filesystem) walkExistingLocked(ctx context.Context, rp *vfs.ResolvingP // walkParentDirLocked is loosely analogous to Linux's // fs/namei.c:path_parentat(). // -// Preconditions: Filesystem.mu must be locked for at least reading. !rp.Done(). +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * !rp.Done(). // // Postconditions: Caller must call fs.processDeferredDecRefs*. func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath) (*vfs.Dentry, Inode, error) { @@ -193,8 +200,10 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // checkCreateLocked checks that a file named rp.Component() may be created in // directory parentVFSD, then returns rp.Component(). // -// Preconditions: Filesystem.mu must be locked for at least reading. parentInode -// == parentVFSD.Impl().(*Dentry).Inode. isDir(parentInode) == true. +// Preconditions: +// * Filesystem.mu must be locked for at least reading. +// * parentInode == parentVFSD.Impl().(*Dentry).Inode. +// * isDir(parentInode) == true. func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parentVFSD *vfs.Dentry, parentInode Inode) (string, error) { if err := parentInode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { return "", err @@ -351,7 +360,10 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v defer rp.Mount().EndWrite() childVFSD, err := parentInode.NewDir(ctx, pc, opts) if err != nil { - return err + if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { + return err + } + childVFSD = newSyntheticDirectory(rp.Credentials(), opts.Mode) } parentVFSD.Impl().(*Dentry).InsertChild(pc, childVFSD.Impl().(*Dentry)) return nil @@ -397,15 +409,21 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // Do not create new file. if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() - defer fs.processDeferredDecRefs(ctx) - defer fs.mu.RUnlock() vfsd, inode, err := fs.walkExistingLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() + fs.processDeferredDecRefs(ctx) return nil, err } if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { + fs.mu.RUnlock() + fs.processDeferredDecRefs(ctx) return nil, err } + inode.IncRef() + defer inode.DecRef(ctx) + fs.mu.RUnlock() + fs.processDeferredDecRefs(ctx) return inode.Open(ctx, rp, vfsd, opts) } @@ -414,7 +432,14 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf vfsd := rp.Start() inode := vfsd.Impl().(*Dentry).inode fs.mu.Lock() - defer fs.mu.Unlock() + unlocked := false + unlock := func() { + if !unlocked { + fs.mu.Unlock() + unlocked = true + } + } + defer unlock() if rp.Done() { if rp.MustBeDir() { return nil, syserror.EISDIR @@ -425,6 +450,9 @@ func (fs *Filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err := inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { return nil, err } + inode.IncRef() + defer inode.DecRef(ctx) + unlock() return inode.Open(ctx, rp, vfsd, opts) } afterTrailingSymlink: @@ -466,6 +494,9 @@ afterTrailingSymlink: } child := childVFSD.Impl().(*Dentry) parentVFSD.Impl().(*Dentry).InsertChild(pc, child) + child.inode.IncRef() + defer child.inode.DecRef(ctx) + unlock() return child.inode.Open(ctx, rp, childVFSD, opts) } if err != nil { @@ -499,6 +530,9 @@ afterTrailingSymlink: if err := child.inode.CheckPermissions(ctx, rp.Credentials(), ats); err != nil { return nil, err } + child.inode.IncRef() + defer child.inode.DecRef(ctx) + unlock() return child.inode.Open(ctx, rp, &child.vfsd, opts) } @@ -514,7 +548,7 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st if !d.Impl().(*Dentry).isSymlink() { return "", syserror.EINVAL } - return inode.Readlink(ctx) + return inode.Readlink(ctx, rp.Mount()) } // RenameAt implements vfs.FilesystemImpl.RenameAt. @@ -623,6 +657,7 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() + vfsd, inode, err := fs.walkExistingLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { @@ -652,7 +687,8 @@ func (fs *Filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } - if err := parentDentry.inode.RmDir(ctx, rp.Component(), vfsd); err != nil { + + if err := parentDentry.inode.RmDir(ctx, d.name, vfsd); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } @@ -690,14 +726,13 @@ func (fs *Filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // StatFSAt implements vfs.FilesystemImpl.StatFSAt. func (fs *Filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linux.Statfs, error) { fs.mu.RLock() - _, _, err := fs.walkExistingLocked(ctx, rp) + _, inode, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() fs.processDeferredDecRefs(ctx) if err != nil { return linux.Statfs{}, err } - // TODO(gvisor.dev/issue/1193): actually implement statfs. - return linux.Statfs{}, syserror.ENOSYS + return inode.StatFS(ctx, fs.VFSFilesystem()) } // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. @@ -732,6 +767,7 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error { fs.mu.Lock() defer fs.mu.Unlock() + vfsd, _, err := fs.walkExistingLocked(ctx, rp) fs.processDeferredDecRefsLocked(ctx) if err != nil { @@ -757,7 +793,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if err := virtfs.PrepareDeleteDentry(mntns, vfsd); err != nil { return err } - if err := parentDentry.inode.Unlink(ctx, rp.Component(), vfsd); err != nil { + if err := parentDentry.inode.Unlink(ctx, d.name, vfsd); err != nil { virtfs.AbortDeleteDentry(vfsd) return err } @@ -765,7 +801,7 @@ func (fs *Filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return nil } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() _, inode, err := fs.walkExistingLocked(ctx, rp) @@ -780,8 +816,8 @@ func (fs *Filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *Filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() @@ -793,8 +829,8 @@ func (fs *Filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si return nil, syserror.ENOTSUP } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *Filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() @@ -806,8 +842,8 @@ func (fs *Filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return "", syserror.ENOTSUP } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *Filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() @@ -819,8 +855,8 @@ func (fs *Filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return syserror.ENOTSUP } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *Filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *Filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() _, _, err := fs.walkExistingLocked(ctx, rp) fs.mu.RUnlock() diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index c3efcf3ec..6ee353ace 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -20,7 +20,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -48,10 +47,6 @@ func (InodeNoopRefCount) TryIncRef() bool { return true } -// Destroy implements Inode.Destroy. -func (InodeNoopRefCount) Destroy(context.Context) { -} - // InodeDirectoryNoNewChildren partially implements the Inode interface. // InodeDirectoryNoNewChildren represents a directory inode which does not // support creation of new children. @@ -177,7 +172,7 @@ func (InodeNoDynamicLookup) Valid(ctx context.Context) bool { type InodeNotSymlink struct{} // Readlink implements Inode.Readlink. -func (InodeNotSymlink) Readlink(context.Context) (string, error) { +func (InodeNotSymlink) Readlink(context.Context, *vfs.Mount) (string, error) { return "", syserror.EINVAL } @@ -261,12 +256,29 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return a.SetInodeStat(ctx, fs, creds, opts) +} + +// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. +// This function can be used by other kernfs-based filesystem implementation to +// sets the unexported attributes into kernfs.InodeAttrs. +func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } - if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID) != 0 { + + // Note that not all fields are modifiable. For example, the file type and + // inode numbers are immutable after node creation. Setting the size is often + // allowed by kernfs files but does not do anything. If some other behavior is + // needed, the embedder should consider extending SetStat. + // + // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { return syserror.EPERM } + if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { + return syserror.EISDIR + } if err := vfs.CheckSetStat(ctx, creds, &opts, a.Mode(), auth.KUID(atomic.LoadUint32(&a.uid)), auth.KGID(atomic.LoadUint32(&a.gid))); err != nil { return err } @@ -289,13 +301,6 @@ func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *aut atomic.StoreUint32(&a.gid, stat.GID) } - // Note that not all fields are modifiable. For example, the file type and - // inode numbers are immutable after node creation. - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - // Also, STATX_SIZE will need some special handling, because read-only static - // files should return EIO for truncate operations. - return nil } @@ -348,8 +353,6 @@ type OrderedChildrenOptions struct { // // Must be initialize with Init before first use. type OrderedChildren struct { - refs.AtomicRefCount - // Can children be modified by user syscalls? It set to false, interface // methods that would modify the children return EPERM. Immutable. writable bool @@ -365,13 +368,10 @@ func (o *OrderedChildren) Init(opts OrderedChildrenOptions) { o.set = make(map[string]*slot) } -// DecRef implements Inode.DecRef. -func (o *OrderedChildren) DecRef(ctx context.Context) { - o.AtomicRefCount.DecRefWithDestructor(ctx, o.Destroy) -} - -// Destroy cleans up resources referenced by this OrderedChildren. -func (o *OrderedChildren) Destroy(context.Context) { +// Destroy clears the children stored in o. It should be called by structs +// embedding OrderedChildren upon destruction, i.e. when their reference count +// reaches zero. +func (o *OrderedChildren) Destroy() { o.mu.Lock() defer o.mu.Unlock() o.order.Reset() @@ -556,21 +556,24 @@ func (InodeSymlink) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.D // // +stateify savable type StaticDirectory struct { - InodeNotSymlink - InodeDirectoryNoNewChildren InodeAttrs + InodeDirectoryNoNewChildren InodeNoDynamicLookup + InodeNoStatFS + InodeNotSymlink OrderedChildren + StaticDirectoryRefs - locks vfs.FileLocks + locks vfs.FileLocks + fdOpts GenericDirectoryFDOptions } var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry) *Dentry { +func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*Dentry, fdOpts GenericDirectoryFDOptions) *Dentry { inode := &StaticDirectory{} - inode.Init(creds, devMajor, devMinor, ino, perm) + inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) dentry := &Dentry{} dentry.Init(inode) @@ -583,31 +586,46 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } + s.fdOpts = fdOpts s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } -// Open implements kernfs.Inode. +// Open implements kernfs.Inode.Open. func (s *StaticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts) + fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &s.OrderedChildren, &s.locks, &opts, s.fdOpts) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*StaticDirectory) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } +// DecRef implements kernfs.Inode.DecRef. +func (s *StaticDirectory) DecRef(context.Context) { + s.StaticDirectoryRefs.DecRef(s.Destroy) +} + // AlwaysValid partially implements kernfs.inodeDynamicLookup. type AlwaysValid struct{} -// Valid implements kernfs.inodeDynamicLookup. +// Valid implements kernfs.inodeDynamicLookup.Valid. func (*AlwaysValid) Valid(context.Context) bool { return true } + +// InodeNoStatFS partially implements the Inode interface, where the client +// filesystem doesn't support statfs(2). +type InodeNoStatFS struct{} + +// StatFS implements Inode.StatFS. +func (*InodeNoStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return linux.Statfs{}, syserror.ENOSYS +} diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 080118841..163f26ceb 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -57,10 +57,10 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" ) // Filesystem mostly implements vfs.FilesystemImpl for a generic in-memory @@ -161,9 +161,9 @@ const ( // // Must be initialized by Init prior to first use. type Dentry struct { - vfsd vfs.Dentry + DentryRefs - refs.AtomicRefCount + vfsd vfs.Dentry // flags caches useful information about the dentry from the inode. See the // dflags* consts above. Must be accessed by atomic ops. @@ -194,6 +194,7 @@ func (d *Dentry) Init(inode Inode) { if ftype == linux.ModeSymlink { d.flags |= dflagsIsSymlink } + d.EnableLeakCheck() } // VFSDentry returns the generic vfs dentry for this kernfs dentry. @@ -213,16 +214,14 @@ func (d *Dentry) isSymlink() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *Dentry) DecRef(ctx context.Context) { - d.AtomicRefCount.DecRefWithDestructor(ctx, d.destroy) -} - -// Precondition: Dentry must be removed from VFS' dentry cache. -func (d *Dentry) destroy(ctx context.Context) { - d.inode.DecRef(ctx) // IncRef from Init. - d.inode = nil - if d.parent != nil { - d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild. - } + // Before the destructor is called, Dentry must be removed from VFS' dentry cache. + d.DentryRefs.DecRef(func() { + d.inode.DecRef(ctx) // IncRef from Init. + d.inode = nil + if d.parent != nil { + d.parent.DecRef(ctx) // IncRef from Dentry.InsertChild. + } + }) } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -248,15 +247,15 @@ func (d *Dentry) OnZeroWatches(context.Context) {} // Precondition: d must represent a directory inode. func (d *Dentry) InsertChild(name string, child *Dentry) { d.dirMu.Lock() - d.insertChildLocked(name, child) + d.InsertChildLocked(name, child) d.dirMu.Unlock() } -// insertChildLocked is equivalent to InsertChild, with additional +// InsertChildLocked is equivalent to InsertChild, with additional // preconditions. // // Precondition: d.dirMu must be locked. -func (d *Dentry) insertChildLocked(name string, child *Dentry) { +func (d *Dentry) InsertChildLocked(name string, child *Dentry) { if !d.isDir() { panic(fmt.Sprintf("InsertChild called on non-directory Dentry: %+v.", d)) } @@ -269,6 +268,36 @@ func (d *Dentry) insertChildLocked(name string, child *Dentry) { d.children[name] = child } +// RemoveChild removes child from the vfs dentry cache. This does not update the +// directory inode or modify the inode to be unlinked. So calling this on its own +// isn't sufficient to remove a child from a directory. +// +// Precondition: d must represent a directory inode. +func (d *Dentry) RemoveChild(name string, child *vfs.Dentry) error { + d.dirMu.Lock() + defer d.dirMu.Unlock() + return d.RemoveChildLocked(name, child) +} + +// RemoveChildLocked is equivalent to RemoveChild, with additional +// preconditions. +// +// Precondition: d.dirMu must be locked. +func (d *Dentry) RemoveChildLocked(name string, child *vfs.Dentry) error { + if !d.isDir() { + panic(fmt.Sprintf("RemoveChild called on non-directory Dentry: %+v.", d)) + } + c, ok := d.children[name] + if !ok { + return syserror.ENOENT + } + if &c.vfsd != child { + panic(fmt.Sprintf("Dentry hashed into inode doesn't match what vfs thinks! Child: %+v, vfs: %+v", c, child)) + } + delete(d.children, name) + return nil +} + // Inode returns the dentry's inode. func (d *Dentry) Inode() Inode { return d.inode @@ -322,16 +351,17 @@ type Inode interface { // Precondition: rp.Done(). vfsd.Impl() must be the kernfs Dentry containing // the inode on which Open() is being called. Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) + + // StatFS returns filesystem statistics for the client filesystem. This + // corresponds to vfs.FilesystemImpl.StatFSAt. If the client filesystem + // doesn't support statfs(2), this should return ENOSYS. + StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) } type inodeRefs interface { IncRef() DecRef(ctx context.Context) TryIncRef() bool - // Destroy is called when the inode reaches zero references. Destroy release - // all resources (references) on objects referenced by the inode, including - // any child dentries. - Destroy(ctx context.Context) } type inodeMetadata interface { @@ -426,7 +456,7 @@ type inodeDynamicLookup interface { Valid(ctx context.Context) bool // IterDirents is used to iterate over dynamically created entries. It invokes - // cb on each entry in the directory represented by the FileDescription. + // cb on each entry in the directory represented by the Inode. // 'offset' is the offset for the entire IterDirents call, which may include // results from the caller (e.g. "." and ".."). 'relOffset' is the offset // inside the entries returned by this IterDirents invocation. In other words, @@ -438,7 +468,7 @@ type inodeDynamicLookup interface { type inodeSymlink interface { // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. - Readlink(ctx context.Context) (string, error) + Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) // Getlink returns the target of a symbolic link, as used by path // resolution: diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index c5d5afedf..09806a3f2 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -52,7 +52,7 @@ func newTestSystem(t *testing.T, rootFn RootDentryFn) *testutil.System { v.MustRegisterFilesystemType("testfs", &fsType{rootFn: rootFn}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.GetFilesystemOptions{}) + mns, err := v.NewMountNamespace(ctx, creds, "", "testfs", &vfs.MountOptions{}) if err != nil { t.Fatalf("Failed to create testfs root mount: %v", err) } @@ -96,10 +96,12 @@ func (*attrs) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.S } type readonlyDir struct { + readonlyDirRefs attrs - kernfs.InodeNotSymlink - kernfs.InodeNoDynamicLookup kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNoDynamicLookup + kernfs.InodeNoStatFS + kernfs.InodeNotSymlink kernfs.OrderedChildren locks vfs.FileLocks @@ -111,6 +113,7 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod dir := &readonlyDir{} dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + dir.EnableLeakCheck() dir.dentry.Init(dir) dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) @@ -119,18 +122,26 @@ func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMod } func (d *readonlyDir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } +func (d *readonlyDir) DecRef(context.Context) { + d.readonlyDirRefs.DecRef(d.Destroy) +} + type dir struct { + dirRefs attrs - kernfs.InodeNotSymlink kernfs.InodeNoDynamicLookup + kernfs.InodeNotSymlink kernfs.OrderedChildren + kernfs.InodeNoStatFS locks vfs.FileLocks @@ -143,6 +154,7 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte dir.fs = fs dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) + dir.EnableLeakCheck() dir.dentry.Init(dir) dir.IncLinks(dir.OrderedChildren.Populate(&dir.dentry, contents)) @@ -151,13 +163,19 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte } func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } +func (d *dir) DecRef(context.Context) { + d.dirRefs.DecRef(d.Destroy) +} + func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { creds := auth.CredentialsFromContext(ctx) dir := d.fs.newDir(creds, opts.Mode, nil) diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 2ab3f53fd..443121c99 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -28,6 +28,7 @@ type StaticSymlink struct { InodeAttrs InodeNoopRefCount InodeSymlink + InodeNoStatFS target string } @@ -50,8 +51,8 @@ func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } -// Readlink implements Inode. -func (s *StaticSymlink) Readlink(_ context.Context) (string, error) { +// Readlink implements Inode.Readlink. +func (s *StaticSymlink) Readlink(_ context.Context, _ *vfs.Mount) (string, error) { return s.target, nil } diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go new file mode 100644 index 000000000..01ba72fa8 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -0,0 +1,102 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// syntheticDirectory implements kernfs.Inode for a directory created by +// MkdirAt(ForSyntheticMountpoint=true). +// +// +stateify savable +type syntheticDirectory struct { + InodeAttrs + InodeNoStatFS + InodeNoopRefCount + InodeNoDynamicLookup + InodeNotSymlink + OrderedChildren + + locks vfs.FileLocks +} + +var _ Inode = (*syntheticDirectory)(nil) + +func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) *vfs.Dentry { + inode := &syntheticDirectory{} + inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + d := &Dentry{} + d.Init(inode) + return &d.vfsd +} + +func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) + } + dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.OrderedChildren.Init(OrderedChildrenOptions{ + Writable: true, + }) +} + +// Open implements Inode.Open. +func (dir *syntheticDirectory) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd, err := NewGenericDirectoryFD(rp.Mount(), vfsd, &dir.OrderedChildren, &dir.locks, &opts, GenericDirectoryFDOptions{}) + if err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// NewFile implements Inode.NewFile. +func (dir *syntheticDirectory) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewDir implements Inode.NewDir. +func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (*vfs.Dentry, error) { + if !opts.ForSyntheticMountpoint { + return nil, syserror.EPERM + } + subdird := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + if err := dir.OrderedChildren.Insert(name, subdird); err != nil { + subdird.DecRef(ctx) + return nil, err + } + return subdird, nil +} + +// NewLink implements Inode.NewLink. +func (dir *syntheticDirectory) NewLink(ctx context.Context, name string, target Inode) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewSymlink implements Inode.NewSymlink. +func (dir *syntheticDirectory) NewSymlink(ctx context.Context, name, target string) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} + +// NewNode implements Inode.NewNode. +func (dir *syntheticDirectory) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) (*vfs.Dentry, error) { + return nil, syserror.EPERM +} diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index b3d19ff82..73b126669 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -22,6 +22,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" @@ -40,6 +42,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return nil } + // Attach our credentials to the context, as some VFS operations use + // credentials from context rather an take an explicit creds parameter. + ctx = auth.ContextWithCredentials(ctx, d.fs.creds) + ftype := atomic.LoadUint32(&d.mode) & linux.S_IFMT switch ftype { case linux.S_IFREG, linux.S_IFDIR, linux.S_IFLNK, linux.S_IFBLK, linux.S_IFCHR: @@ -76,6 +82,8 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { Start: d.parent.upperVD, Path: fspath.Parse(d.name), } + // Used during copy-up of memory-mapped regular files. + var mmapOpts *memmap.MMapOpts cleanupUndoCopyUp := func() { var err error if ftype == linux.S_IFDIR { @@ -84,7 +92,11 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { err = vfsObj.UnlinkAt(ctx, d.fs.creds, &newpop) } if err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after copy-up error: %v", err)) + } + if d.upperVD.Ok() { + d.upperVD.DecRef(ctx) + d.upperVD = vfs.VirtualDentry{} } } switch ftype { @@ -127,6 +139,25 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { break } } + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if d.wrappedMappable != nil { + // We may have memory mappings of the file on the lower layer. + // Switch to mapping the file on the upper layer instead. + mmapOpts = &memmap.MMapOpts{ + Perms: usermem.ReadWrite, + MaxPerms: usermem.ReadWrite, + } + if err := newFD.ConfigureMMap(ctx, mmapOpts); err != nil { + cleanupUndoCopyUp() + return err + } + if mmapOpts.MappingIdentity != nil { + mmapOpts.MappingIdentity.DecRef(ctx) + } + // Don't actually switch Mappables until the end of copy-up; see + // below for why. + } if err := newFD.SetStat(ctx, vfs.SetStatOptions{ Stat: linux.Statx{ Mask: linux.STATX_UID | linux.STATX_GID, @@ -229,7 +260,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { panic(fmt.Sprintf("unexpected file type %o", ftype)) } - // TODO(gvisor.dev/issue/1199): copy up xattrs + if err := d.copyXattrsLocked(ctx); err != nil { + cleanupUndoCopyUp() + return err + } // Update the dentry's device and inode numbers (except for directories, // for which these remain overlay-assigned). @@ -241,14 +275,10 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { Mask: linux.STATX_INO, }) if err != nil { - d.upperVD.DecRef(ctx) - d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return err } if upperStat.Mask&linux.STATX_INO == 0 { - d.upperVD.DecRef(ctx) - d.upperVD = vfs.VirtualDentry{} cleanupUndoCopyUp() return syserror.EREMOTE } @@ -257,6 +287,135 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { atomic.StoreUint64(&d.ino, upperStat.Ino) } + if mmapOpts != nil && mmapOpts.Mappable != nil { + // Note that if mmapOpts != nil, then d.mapsMu is locked for writing + // (from the S_IFREG path above). + + // Propagate mappings of d to the new Mappable. Remember which mappings + // we added so we can remove them on failure. + upperMappable := mmapOpts.Mappable + allAdded := make(map[memmap.MappableRange]memmap.MappingsOfRange) + for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + added := make(memmap.MappingsOfRange) + for m := range seg.Value() { + if err := upperMappable.AddMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable); err != nil { + for m := range added { + upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable) + } + for mr, mappings := range allAdded { + for m := range mappings { + upperMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, mr.Start, m.Writable) + } + } + return err + } + added[m] = struct{}{} + } + allAdded[seg.Range()] = added + } + + // Switch to the new Mappable. We do this at the end of copy-up + // because: + // + // - We need to switch Mappables (by changing d.wrappedMappable) before + // invalidating Translations from the old Mappable (to pick up + // Translations from the new one). + // + // - We need to lock d.dataMu while changing d.wrappedMappable, but + // must invalidate Translations with d.dataMu unlocked (due to lock + // ordering). + // + // - Consequently, once we unlock d.dataMu, other threads may + // immediately observe the new (copied-up) Mappable, which we want to + // delay until copy-up is guaranteed to succeed. + d.dataMu.Lock() + lowerMappable := d.wrappedMappable + d.wrappedMappable = upperMappable + d.dataMu.Unlock() + d.lowerMappings.InvalidateAll(memmap.InvalidateOpts{}) + + // Remove mappings from the old Mappable. + for seg := d.lowerMappings.FirstSegment(); seg.Ok(); seg = seg.NextSegment() { + for m := range seg.Value() { + lowerMappable.RemoveMapping(ctx, m.MappingSpace, m.AddrRange, seg.Start(), m.Writable) + } + } + d.lowerMappings.RemoveAll() + } + atomic.StoreUint32(&d.copiedUp, 1) return nil } + +// copyXattrsLocked copies a subset of lower's extended attributes to upper. +// Attributes that configure an overlay in the lower are not copied up. +// +// Preconditions: d.copyMu must be locked for writing. +func (d *dentry) copyXattrsLocked(ctx context.Context) error { + vfsObj := d.fs.vfsfs.VirtualFilesystem() + lowerPop := &vfs.PathOperation{Root: d.lowerVDs[0], Start: d.lowerVDs[0]} + upperPop := &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD} + + lowerXattrs, err := vfsObj.ListXattrAt(ctx, d.fs.creds, lowerPop, 0) + if err != nil { + if err == syserror.EOPNOTSUPP { + // There are no guarantees as to the contents of lowerXattrs. + return nil + } + ctx.Infof("failed to copy up xattrs because ListXattrAt failed: %v", err) + return err + } + + for _, name := range lowerXattrs { + // Do not copy up overlay attributes. + if isOverlayXattr(name) { + continue + } + + value, err := vfsObj.GetXattrAt(ctx, d.fs.creds, lowerPop, &vfs.GetXattrOptions{Name: name, Size: 0}) + if err != nil { + ctx.Infof("failed to copy up xattrs because GetXattrAt failed: %v", err) + return err + } + + if err := vfsObj.SetXattrAt(ctx, d.fs.creds, upperPop, &vfs.SetXattrOptions{Name: name, Value: value}); err != nil { + ctx.Infof("failed to copy up xattrs because SetXattrAt failed: %v", err) + return err + } + } + return nil +} + +// copyUpDescendantsLocked ensures that all descendants of d are copied up. +// +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) error { + dirents, err := d.getDirentsLocked(ctx) + if err != nil { + return err + } + for _, dirent := range dirents { + if dirent.Name == "." || dirent.Name == ".." { + continue + } + child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) + if err != nil { + return err + } + if err := child.copyUpLocked(ctx); err != nil { + return err + } + if child.isDir() { + child.dirMu.Lock() + err := child.copyUpDescendantsLocked(ctx, ds) + child.dirMu.Unlock() + if err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/sentry/fsimpl/overlay/directory.go b/pkg/sentry/fsimpl/overlay/directory.go index 6a79f7ffe..7ab42e71e 100644 --- a/pkg/sentry/fsimpl/overlay/directory.go +++ b/pkg/sentry/fsimpl/overlay/directory.go @@ -29,7 +29,9 @@ func (d *dentry) isDir() bool { return atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR } -// Preconditions: d.dirMu must be locked. d.isDir(). +// Preconditions: +// * d.dirMu must be locked. +// * d.isDir(). func (d *dentry) collectWhiteoutsForRmdirLocked(ctx context.Context) (map[string]bool, error) { vfsObj := d.fs.vfsfs.VirtualFilesystem() var readdirErr error @@ -141,7 +143,14 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { defer d.fs.renameMu.RUnlock() d.dirMu.Lock() defer d.dirMu.Unlock() + return d.getDirentsLocked(ctx) +} +// Preconditions: +// * filesystem.renameMu must be locked. +// * d.dirMu must be locked. +// * d.isDir(). +func (d *dentry) getDirentsLocked(ctx context.Context) ([]vfs.Dirent, error) { if d.dirents != nil { return d.dirents, nil } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 986b36ead..e9ce4bde1 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -15,6 +15,8 @@ package overlay import ( + "fmt" + "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -27,10 +29,15 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) +// _OVL_XATTR_PREFIX is an extended attribute key prefix to identify overlayfs +// attributes. +// Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_PREFIX +const _OVL_XATTR_PREFIX = linux.XATTR_TRUSTED_PREFIX + "overlay." + // _OVL_XATTR_OPAQUE is an extended attribute key whose value is set to "y" for // opaque directories. // Linux: fs/overlayfs/overlayfs.h:OVL_XATTR_OPAQUE -const _OVL_XATTR_OPAQUE = "trusted.overlay.opaque" +const _OVL_XATTR_OPAQUE = _OVL_XATTR_PREFIX + "opaque" func isWhiteout(stat *linux.Statx) bool { return stat.Mode&linux.S_IFMT == linux.S_IFCHR && stat.RdevMajor == 0 && stat.RdevMinor == 0 @@ -110,8 +117,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // Dentries which may have a reference count of zero, and which therefore // should be dropped once traversal is complete, are appended to ds. // -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !rp.Done(). func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { if !d.isDir() { return nil, syserror.ENOTDIR @@ -159,7 +168,9 @@ afterSymlink: return child, nil } -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { return child, nil @@ -177,7 +188,9 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s return child, nil } -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { childPath := fspath.Parse(name) child := fs.newDentry() @@ -199,6 +212,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str lookupErr = err return false } + defer childVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE) if !existsOnAnyLayer { @@ -237,6 +251,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } // Update child to include this layer. + childVD.IncRef() if isUpper { child.upperVD = childVD child.copiedUp = 1 @@ -261,10 +276,10 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str // Directories are merged with directories from lower layers if they // are not explicitly opaque. - opaqueVal, err := vfsObj.GetxattrAt(ctx, fs.creds, &vfs.PathOperation{ + opaqueVal, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{ Root: childVD, Start: childVD, - }, &vfs.GetxattrOptions{ + }, &vfs.GetXattrOptions{ Name: _OVL_XATTR_OPAQUE, Size: 1, }) @@ -300,7 +315,9 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str // lookupLayerLocked is similar to lookupLocked, but only returns information // about the file rather than a dentry. // -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) lookupLayerLocked(ctx context.Context, parent *dentry, name string) (lookupLayer, error) { childPath := fspath.Parse(name) lookupLayer := lookupLayerNone @@ -385,7 +402,9 @@ func (ll lookupLayer) existsInOverlay() bool { // rp.Start().Impl().(*dentry)). It does not check that the returned directory // is searchable by the provider of rp. // -// Preconditions: fs.renameMu must be locked. !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() @@ -425,8 +444,9 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, // doCreateAt checks that creating a file at rp is permitted, then invokes // create to do so. // -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). +// Preconditions: +// * !rp.Done(). +// * For the final path component in rp, !rp.ShouldFollowSymlink(). func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parent *dentry, name string, haveUpperWhiteout bool) error) error { var ds *[]*dentry fs.renameMu.RLock() @@ -493,7 +513,7 @@ func (fs *filesystem) createWhiteout(ctx context.Context, vfsObj *vfs.VirtualFil func (fs *filesystem) cleanupRecreateWhiteout(ctx context.Context, vfsObj *vfs.VirtualFilesystem, pop *vfs.PathOperation) { if err := fs.createWhiteout(ctx, vfsObj, pop); err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate whiteout after failed file creation: %v", err)) } } @@ -605,7 +625,7 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after LinkAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &newpop) } @@ -644,7 +664,7 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }, }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -654,12 +674,12 @@ func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // There may be directories on lower layers (previously hidden by // the whiteout) that the new directory should not be merged with. // Mark it opaque to prevent merging. - if err := vfsObj.SetxattrAt(ctx, fs.creds, &pop, &vfs.SetxattrOptions{ + if err := vfsObj.SetXattrAt(ctx, fs.creds, &pop, &vfs.SetXattrOptions{ Name: _OVL_XATTR_OPAQUE, Value: "y", }); err != nil { if cleanupErr := vfsObj.RmdirAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer directory after MkdirAt set-opaque failure: %v", cleanupErr)) } else { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -703,7 +723,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after MknodAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -717,17 +737,36 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { mayCreate := opts.Flags&linux.O_CREAT != 0 mustCreate := opts.Flags&(linux.O_CREAT|linux.O_EXCL) == (linux.O_CREAT | linux.O_EXCL) + mayWrite := vfs.AccessTypesForOpenFlags(&opts).MayWrite() var ds *[]*dentry fs.renameMu.RLock() - defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + unlocked := false + unlock := func() { + if !unlocked { + fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + unlocked = true + } + } + defer unlock() start := rp.Start().Impl().(*dentry) if rp.Done() { + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } if mustCreate { return nil, syserror.EEXIST } - return start.openLocked(ctx, rp, &opts) + if mayWrite { + if err := start.copyUpLocked(ctx); err != nil { + return nil, err + } + } + start.IncRef() + defer start.DecRef(ctx) + unlock() + return start.openCopiedUp(ctx, rp, &opts) } afterTrailingSymlink: @@ -739,6 +778,10 @@ afterTrailingSymlink: if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { return nil, err } + // Reject attempts to open directories with O_CREAT. + if mayCreate && rp.MustBeDir() { + return nil, syserror.EISDIR + } // Determine whether or not we need to create a file. parent.dirMu.Lock() child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) @@ -747,12 +790,11 @@ afterTrailingSymlink: parent.dirMu.Unlock() return fd, err } + parent.dirMu.Unlock() if err != nil { - parent.dirMu.Unlock() return nil, err } // Open existing child or follow symlink. - parent.dirMu.Unlock() if mustCreate { return nil, syserror.EEXIST } @@ -767,20 +809,27 @@ afterTrailingSymlink: start = parent goto afterTrailingSymlink } - return child.openLocked(ctx, rp, &opts) + if rp.MustBeDir() && !child.isDir() { + return nil, syserror.ENOTDIR + } + if mayWrite { + if err := child.copyUpLocked(ctx); err != nil { + return nil, err + } + } + child.IncRef() + defer child.DecRef(ctx) + unlock() + return child.openCopiedUp(ctx, rp, &opts) } -// Preconditions: fs.renameMu must be locked. -func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { +// Preconditions: If vfs.AccessTypesForOpenFlags(opts).MayWrite(), then d has +// been copied up. +func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if err := d.checkPermissions(rp.Credentials(), ats); err != nil { return nil, err } - if ats.MayWrite() { - if err := d.copyUpLocked(ctx); err != nil { - return nil, err - } - } mnt := rp.Mount() // Directory FDs open FDs from each layer when directory entries are read, @@ -792,7 +841,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return nil, syserror.EISDIR } // Can't open directories writably. - if ats&vfs.MayWrite != 0 { + if ats.MayWrite() { return nil, syserror.EISDIR } if opts.Flags&linux.O_DIRECT != 0 { @@ -831,8 +880,9 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf return &fd.vfsfd, nil } -// Preconditions: parent.dirMu must be locked. parent does not already contain -// a child named rp.Component(). +// Preconditions: +// * parent.dirMu must be locked. +// * parent does not already contain a child named rp.Component(). func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { creds := rp.Credentials() if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil { @@ -893,7 +943,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -904,7 +954,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving child, err := fs.getChildLocked(ctx, parent, childName, ds) if err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -970,9 +1020,223 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } defer mnt.EndWrite() - // FIXME(gvisor.dev/issue/1199): Actually implement rename. - _ = newParent - return syserror.EXDEV + oldParent := oldParentVD.Dentry().Impl().(*dentry) + creds := rp.Credentials() + if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + // We need a dentry representing the renamed file since, if it's a + // directory, we need to check for write permission on it. + oldParent.dirMu.Lock() + defer oldParent.dirMu.Unlock() + renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) + if err != nil { + return err + } + if err := vfs.CheckDeleteSticky(creds, linux.FileMode(atomic.LoadUint32(&oldParent.mode)), auth.KUID(atomic.LoadUint32(&renamed.uid))); err != nil { + return err + } + if renamed.isDir() { + if renamed == newParent || genericIsAncestorDentry(renamed, newParent) { + return syserror.EINVAL + } + if oldParent != newParent { + if err := renamed.checkPermissions(creds, vfs.MayWrite); err != nil { + return err + } + } + } else { + if opts.MustBeDir || rp.MustBeDir() { + return syserror.ENOTDIR + } + } + + if oldParent != newParent { + if err := newParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err + } + newParent.dirMu.Lock() + defer newParent.dirMu.Unlock() + } + if newParent.vfsd.IsDead() { + return syserror.ENOENT + } + replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName) + if err != nil { + return err + } + var ( + replaced *dentry + replacedVFSD *vfs.Dentry + whiteouts map[string]bool + ) + if replacedLayer.existsInOverlay() { + replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil { + return err + } + replacedVFSD = &replaced.vfsd + if replaced.isDir() { + if !renamed.isDir() { + return syserror.EISDIR + } + if genericIsAncestorDentry(replaced, renamed) { + return syserror.ENOTEMPTY + } + replaced.dirMu.Lock() + defer replaced.dirMu.Unlock() + whiteouts, err = replaced.collectWhiteoutsForRmdirLocked(ctx) + if err != nil { + return err + } + } else { + if rp.MustBeDir() || renamed.isDir() { + return syserror.ENOTDIR + } + } + } + + if oldParent == newParent && oldName == newName { + return nil + } + + // renamed and oldParent need to be copied-up before they're renamed on the + // upper layer. + if err := renamed.copyUpLocked(ctx); err != nil { + return err + } + // If renamed is a directory, all of its descendants need to be copied-up + // before they're renamed on the upper layer. + if renamed.isDir() { + if err := renamed.copyUpDescendantsLocked(ctx, &ds); err != nil { + return err + } + } + // newParent must be copied-up before it can contain renamed on the upper + // layer. + if err := newParent.copyUpLocked(ctx); err != nil { + return err + } + // If replaced exists, it doesn't need to be copied-up, but we do need to + // serialize with copy-up. Holding renameMu for writing should be + // sufficient, but out of an abundance of caution... + if replaced != nil { + replaced.copyMu.RLock() + defer replaced.copyMu.RUnlock() + } + + vfsObj := rp.VirtualFilesystem() + mntns := vfs.MountNamespaceFromContext(ctx) + defer mntns.DecRef(ctx) + if err := vfsObj.PrepareRenameDentry(mntns, &renamed.vfsd, replacedVFSD); err != nil { + return err + } + + newpop := vfs.PathOperation{ + Root: newParent.upperVD, + Start: newParent.upperVD, + Path: fspath.Parse(newName), + } + + needRecreateWhiteouts := false + cleanupRecreateWhiteouts := func() { + if !needRecreateWhiteouts { + return + } + for whiteoutName, whiteoutUpper := range whiteouts { + if !whiteoutUpper { + continue + } + if err := fs.createWhiteout(ctx, vfsObj, &vfs.PathOperation{ + Root: replaced.upperVD, + Start: replaced.upperVD, + Path: fspath.Parse(whiteoutName), + }); err != nil && err != syserror.EEXIST { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RenameAt failure: %v", err)) + } + } + } + if renamed.isDir() { + if replacedLayer == lookupLayerUpper { + // Remove whiteouts from the directory being replaced. + needRecreateWhiteouts = true + for whiteoutName, whiteoutUpper := range whiteouts { + if !whiteoutUpper { + continue + } + if err := vfsObj.UnlinkAt(ctx, fs.creds, &vfs.PathOperation{ + Root: replaced.upperVD, + Start: replaced.upperVD, + Path: fspath.Parse(whiteoutName), + }); err != nil { + cleanupRecreateWhiteouts() + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + } + } else if replacedLayer == lookupLayerUpperWhiteout { + // We need to explicitly remove the whiteout since otherwise rename + // on the upper layer will fail with ENOTDIR. + if err := vfsObj.UnlinkAt(ctx, fs.creds, &newpop); err != nil { + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + } + } + + // Essentially no gVisor filesystem supports RENAME_WHITEOUT, so just do a + // regular rename and create the whiteout at the origin manually. Unlike + // RENAME_WHITEOUT, this isn't atomic with respect to other users of the + // upper filesystem, but this is already the case for virtually all other + // overlay filesystem operations too. + oldpop := vfs.PathOperation{ + Root: oldParent.upperVD, + Start: oldParent.upperVD, + Path: fspath.Parse(oldName), + } + if err := vfsObj.RenameAt(ctx, creds, &oldpop, &newpop, &opts); err != nil { + cleanupRecreateWhiteouts() + vfsObj.AbortRenameDentry(&renamed.vfsd, replacedVFSD) + return err + } + + // Below this point, the renamed dentry is now at newpop, and anything we + // replaced is gone forever. Commit the rename, update the overlay + // filesystem tree, and abandon attempts to recover from errors. + vfsObj.CommitRenameReplaceDentry(ctx, &renamed.vfsd, replacedVFSD) + delete(oldParent.children, oldName) + if replaced != nil { + ds = appendDentry(ds, replaced) + } + if oldParent != newParent { + newParent.dirents = nil + // This can't drop the last reference on oldParent because one is held + // by oldParentVD, so lock recursion is impossible. + oldParent.DecRef(ctx) + ds = appendDentry(ds, oldParent) + newParent.IncRef() + renamed.parent = newParent + } + renamed.name = newName + if newParent.children == nil { + newParent.children = make(map[string]*dentry) + } + newParent.children[newName] = renamed + oldParent.dirents = nil + + if err := fs.createWhiteout(ctx, vfsObj, &oldpop); err != nil { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout at origin after RenameAt: %v", err)) + } + if renamed.isDir() { + if err := vfsObj.SetXattrAt(ctx, fs.creds, &newpop, &vfs.SetXattrOptions{ + Name: _OVL_XATTR_OPAQUE, + Value: "y", + }); err != nil { + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to make renamed directory opaque: %v", err)) + } + } + + return nil } // RmdirAt implements vfs.FilesystemImpl.RmdirAt. @@ -1051,7 +1315,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error Start: child.upperVD, Path: fspath.Parse(whiteoutName), }); err != nil && err != syserror.EEXIST { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to recreate deleted whiteout after RmdirAt failure: %v", err)) } } } @@ -1081,9 +1345,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // Don't attempt to recover from this: the original directory is // already gone, so any dentries representing it are invalid, and // creating a new directory won't undo that. - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err) - vfsObj.AbortDeleteDentry(&child.vfsd) - return err + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during RmdirAt: %v", err)) } vfsObj.CommitDeleteDentry(ctx, &child.vfsd) @@ -1197,7 +1459,7 @@ func (fs *filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ }, }); err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr) + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after SymlinkAt metadata update failure: %v", cleanupErr)) } else if haveUpperWhiteout { fs.cleanupRecreateWhiteout(ctx, vfsObj, &pop) } @@ -1290,11 +1552,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error } } if err := fs.createWhiteout(ctx, vfsObj, &pop); err != nil { - ctx.Warningf("Unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err) - if child != nil { - vfsObj.AbortDeleteDentry(&child.vfsd) - } - return err + panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to create whiteout during UnlinkAt: %v", err)) } if child != nil { @@ -1306,54 +1564,146 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return nil } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// isOverlayXattr returns whether the given extended attribute configures the +// overlay. +func isOverlayXattr(name string) bool { + return strings.HasPrefix(name, _OVL_XATTR_PREFIX) +} + +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return nil, err } - // TODO(gvisor.dev/issue/1199): Linux overlayfs actually allows listxattr, - // but not any other xattr syscalls. For now we just reject all of them. - return nil, syserror.ENOTSUP + + return fs.listXattr(ctx, d, size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +func (fs *filesystem) listXattr(ctx context.Context, d *dentry, size uint64) ([]string, error) { + vfsObj := d.fs.vfsfs.VirtualFilesystem() + top := d.topLayer() + names, err := vfsObj.ListXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, size) + if err != nil { + return nil, err + } + + // Filter out all overlay attributes. + n := 0 + for _, name := range names { + if !isOverlayXattr(name) { + names[n] = name + n++ + } + } + return names[:n], err +} + +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return "", err } - return "", syserror.ENOTSUP + + return fs.getXattr(ctx, d, rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +func (fs *filesystem) getXattr(ctx context.Context, d *dentry, creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { + return "", err + } + + // Return EOPNOTSUPP when fetching an overlay attribute. + // See fs/overlayfs/super.c:ovl_own_xattr_get(). + if isOverlayXattr(opts.Name) { + return "", syserror.EOPNOTSUPP + } + + // Analogous to fs/overlayfs/super.c:ovl_other_xattr_get(). + vfsObj := d.fs.vfsfs.VirtualFilesystem() + top := d.topLayer() + return vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: top, Start: top}, opts) +} + +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err } - return syserror.ENOTSUP + + return fs.setXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), &opts) +} + +// Precondition: fs.renameMu must be locked. +func (fs *filesystem) setXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, opts *vfs.SetXattrOptions) error { + if err := d.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { + return err + } + + // Return EOPNOTSUPP when setting an overlay attribute. + // See fs/overlayfs/super.c:ovl_own_xattr_set(). + if isOverlayXattr(opts.Name) { + return syserror.EOPNOTSUPP + } + + // Analogous to fs/overlayfs/super.c:ovl_other_xattr_set(). + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := d.copyUpLocked(ctx); err != nil { + return err + } + vfsObj := d.fs.vfsfs.VirtualFilesystem() + return vfsObj.SetXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, opts) } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) - _, err := fs.resolveLocked(ctx, rp, &ds) + d, err := fs.resolveLocked(ctx, rp, &ds) if err != nil { return err } - return syserror.ENOTSUP + + return fs.removeXattrLocked(ctx, d, rp.Mount(), rp.Credentials(), name) +} + +// Precondition: fs.renameMu must be locked. +func (fs *filesystem) removeXattrLocked(ctx context.Context, d *dentry, mnt *vfs.Mount, creds *auth.Credentials, name string) error { + if err := d.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { + return err + } + + // Like SetXattrAt, return EOPNOTSUPP when removing an overlay attribute. + // Linux passes the remove request to xattr_handler->set. + // See fs/xattr.c:vfs_removexattr(). + if isOverlayXattr(name) { + return syserror.EOPNOTSUPP + } + + if err := mnt.CheckBeginWrite(); err != nil { + return err + } + defer mnt.EndWrite() + if err := d.copyUpLocked(ctx); err != nil { + return err + } + vfsObj := d.fs.vfsfs.VirtualFilesystem() + return vfsObj.RemoveXattrAt(ctx, fs.creds, &vfs.PathOperation{Root: d.upperVD, Start: d.upperVD}, name) } // PrependPath implements vfs.FilesystemImpl.PrependPath. diff --git a/pkg/sentry/fsimpl/overlay/non_directory.go b/pkg/sentry/fsimpl/overlay/non_directory.go index d3060a481..6e04705c7 100644 --- a/pkg/sentry/fsimpl/overlay/non_directory.go +++ b/pkg/sentry/fsimpl/overlay/non_directory.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -121,7 +122,6 @@ func (fd *nonDirectoryFD) OnClose(ctx context.Context) error { fd.cachedFlags = statusFlags } wrappedFD := fd.cachedFD - defer wrappedFD.IncRef() fd.mu.Unlock() return wrappedFD.OnClose(ctx) } @@ -147,6 +147,16 @@ func (fd *nonDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux return stat, nil } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (fd *nonDirectoryFD) Allocate(ctx context.Context, mode, offset, length uint64) error { + wrappedFD, err := fd.getCurrentFD(ctx) + if err != nil { + return err + } + defer wrappedFD.DecRef(ctx) + return wrappedFD.Allocate(ctx, mode, offset, length) +} + // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *nonDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { d := fd.dentry() @@ -257,10 +267,105 @@ func (fd *nonDirectoryFD) Sync(ctx context.Context) error { // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *nonDirectoryFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { - wrappedFD, err := fd.getCurrentFD(ctx) + if err := fd.ensureMappable(ctx, opts); err != nil { + return err + } + return vfs.GenericConfigureMMap(&fd.vfsfd, fd.dentry(), opts) +} + +// ensureMappable ensures that fd.dentry().wrappedMappable is not nil. +func (fd *nonDirectoryFD) ensureMappable(ctx context.Context, opts *memmap.MMapOpts) error { + d := fd.dentry() + + // Fast path if we already have a Mappable for the current top layer. + if atomic.LoadUint32(&d.isMappable) != 0 { + return nil + } + + // Only permit mmap of regular files, since other file types may have + // unpredictable behavior when mmapped (e.g. /dev/zero). + if atomic.LoadUint32(&d.mode)&linux.S_IFMT != linux.S_IFREG { + return syserror.ENODEV + } + + // Get a Mappable for the current top layer. + fd.mu.Lock() + defer fd.mu.Unlock() + d.copyMu.RLock() + defer d.copyMu.RUnlock() + if atomic.LoadUint32(&d.isMappable) != 0 { + return nil + } + wrappedFD, err := fd.currentFDLocked(ctx) if err != nil { return err } - defer wrappedFD.DecRef(ctx) - return wrappedFD.ConfigureMMap(ctx, opts) + if err := wrappedFD.ConfigureMMap(ctx, opts); err != nil { + return err + } + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef(ctx) + opts.MappingIdentity = nil + } + // Use this Mappable for all mappings of this layer (unless we raced with + // another call to ensureMappable). + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + d.dataMu.Lock() + defer d.dataMu.Unlock() + if d.wrappedMappable == nil { + d.wrappedMappable = opts.Mappable + atomic.StoreUint32(&d.isMappable, 1) + } + return nil +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (d *dentry) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if err := d.wrappedMappable.AddMapping(ctx, ms, ar, offset, writable); err != nil { + return err + } + if !d.isCopiedUp() { + d.lowerMappings.AddMapping(ms, ar, offset, writable) + } + return nil +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (d *dentry) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + d.wrappedMappable.RemoveMapping(ctx, ms, ar, offset, writable) + if !d.isCopiedUp() { + d.lowerMappings.RemoveMapping(ms, ar, offset, writable) + } +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + if err := d.wrappedMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable); err != nil { + return err + } + if !d.isCopiedUp() { + d.lowerMappings.AddMapping(ms, dstAR, offset, writable) + } + return nil +} + +// Translate implements memmap.Mappable.Translate. +func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { + d.dataMu.RLock() + defer d.dataMu.RUnlock() + return d.wrappedMappable.Translate(ctx, required, optional, at) +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (d *dentry) InvalidateUnsavable(ctx context.Context) error { + d.mapsMu.Lock() + defer d.mapsMu.Unlock() + return d.wrappedMappable.InvalidateUnsavable(ctx) } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 75cc006bf..d0d26185e 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -22,6 +22,10 @@ // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// *** "memmap.Mappable locks" below this point +// dentry.mapsMu +// *** "memmap.Mappable locks taken by Translate" below this point +// dentry.dataMu // // Locking dentry.dirMu in multiple dentries requires that parent dentries are // locked before child dentries, and that filesystem.renameMu is locked to @@ -37,6 +41,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -106,16 +111,16 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsoptsRaw := opts.InternalData fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) if fsoptsRaw != nil && !haveFSOpts { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) + ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } if haveFSOpts { if len(fsopts.LowerRoots) == 0 { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") return nil, nil, syserror.EINVAL } if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") return nil, nil, syserror.EINVAL } // We don't enforce a maximum number of lower layers when not @@ -132,7 +137,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt delete(mopts, "workdir") upperPath := fspath.Parse(upperPathname) if !upperPath.Absolute { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ @@ -144,13 +149,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt CheckSearchable: true, }) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) return nil, nil, err } defer upperRoot.DecRef(ctx) privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) return nil, nil, err } defer privateUpperRoot.DecRef(ctx) @@ -158,24 +163,24 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } lowerPathnamesStr, ok := mopts["lowerdir"] if !ok { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") return nil, nil, syserror.EINVAL } if len(lowerPathnames) > maxLowerLayers { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) return nil, nil, syserror.EINVAL } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname) + ctx.Infof("overlay.FilesystemType.GetFilesystem: lowerdir %q must be absolute", lowerPathname) return nil, nil, syserror.EINVAL } lowerRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ @@ -187,13 +192,13 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt CheckSearchable: true, }) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) if err != nil { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err } defer privateLowerRoot.DecRef(ctx) @@ -201,7 +206,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } if len(mopts) != 0 { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) + ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } @@ -274,7 +279,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, syserror.EREMOTE } if isWhiteout(&rootStat) { - ctx.Warningf("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") + ctx.Infof("overlay.FilesystemType.GetFilesystem: filesystem root is a whiteout") root.destroyLocked(ctx) fs.vfsfs.DecRef(ctx) return nil, nil, syserror.EINVAL @@ -315,7 +320,11 @@ func clonePrivateMount(vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, forc if err != nil { return vfs.VirtualDentry{}, err } - return vfs.MakeVirtualDentry(newmnt, vd.Dentry()), nil + // Take a reference on the dentry which will be owned by the returned + // VirtualDentry. + d := vd.Dentry() + d.IncRef() + return vfs.MakeVirtualDentry(newmnt, d), nil } // Release implements vfs.FilesystemImpl.Release. @@ -415,6 +424,35 @@ type dentry struct { devMinor uint32 ino uint64 + // If this dentry represents a regular file, then: + // + // - mapsMu is used to synchronize between copy-up and memmap.Mappable + // methods on dentry preceding mm.MemoryManager.activeMu in the lock order. + // + // - dataMu is used to synchronize between copy-up and + // dentry.(memmap.Mappable).Translate. + // + // - lowerMappings tracks memory mappings of the file. lowerMappings is + // used to invalidate mappings of the lower layer when the file is copied + // up to ensure that they remain coherent with subsequent writes to the + // file. (Note that, as of this writing, Linux overlayfs does not do this; + // this feature is a gVisor extension.) lowerMappings is protected by + // mapsMu. + // + // - If this dentry is copied-up, then wrappedMappable is the Mappable + // obtained from a call to the current top layer's + // FileDescription.ConfigureMMap(). Once wrappedMappable becomes non-nil + // (from a call to nonDirectoryFD.ensureMappable()), it cannot become nil. + // wrappedMappable is protected by mapsMu and dataMu. + // + // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is + // accessed using atomic memory operations. + mapsMu sync.Mutex + lowerMappings memmap.MappingSet + dataMu sync.RWMutex + wrappedMappable memmap.Mappable + isMappable uint32 + locks vfs.FileLocks } @@ -482,7 +520,9 @@ func (d *dentry) checkDropLocked(ctx context.Context) { // destroyLocked destroys the dentry. // -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. +// Preconditions: +// * d.fs.renameMu must be locked for writing. +// * d.refs == 0. func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: @@ -564,6 +604,16 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } +func (d *dentry) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + kuid := auth.KUID(atomic.LoadUint32(&d.uid)) + kgid := auth.KGID(atomic.LoadUint32(&d.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err + } + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) +} + // statInternalMask is the set of stat fields that is set by // dentry.statInternalTo(). const statInternalMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO @@ -616,6 +666,32 @@ func (fd *fileDescription) dentry() *dentry { return fd.vfsfd.Dentry().Impl().(*dentry) } +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.filesystem().listXattr(ctx, fd.dentry(), size) +} + +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.filesystem().getXattr(ctx, fd.dentry(), auth.CredentialsFromContext(ctx), &opts) +} + +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { + fs := fd.filesystem() + fs.renameMu.RLock() + defer fs.renameMu.RUnlock() + return fs.setXattrLocked(ctx, fd.dentry(), fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), &opts) +} + +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { + fs := fd.filesystem() + fs.renameMu.RLock() + defer fs.renameMu.RUnlock() + return fs.removeXattrLocked(ctx, fd.dentry(), fd.vfsfd.Mount(), auth.CredentialsFromContext(ctx), name) +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 2ca793db9..7053ad6db 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -143,14 +143,16 @@ func (i *inode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth. return syserror.EPERM } -// TODO(gvisor.dev/issue/1193): kernfs does not provide a way to implement -// statfs, from which we should indicate PIPEFS_MAGIC. - // Open implements kernfs.Inode.Open. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { return i.pipe.Open(ctx, rp.Mount(), vfsd, opts.Flags, &i.locks) } +// StatFS implements kernfs.Inode.StatFS. +func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.PIPEFS_MAGIC), nil +} + // NewConnectedPipeFDs returns a pair of FileDescriptions representing the read // and write ends of a newly-created pipe, as for pipe(2) and pipe2(2). // diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index f074e6056..2e086e34c 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -1,18 +1,79 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "fd_dir_inode_refs", + out = "fd_dir_inode_refs.go", + package = "proc", + prefix = "fdDirInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "fdDirInode", + }, +) + +go_template_instance( + name = "fd_info_dir_inode_refs", + out = "fd_info_dir_inode_refs.go", + package = "proc", + prefix = "fdInfoDirInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "fdInfoDirInode", + }, +) + +go_template_instance( + name = "subtasks_inode_refs", + out = "subtasks_inode_refs.go", + package = "proc", + prefix = "subtasksInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "subtasksInode", + }, +) + +go_template_instance( + name = "task_inode_refs", + out = "task_inode_refs.go", + package = "proc", + prefix = "taskInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "taskInode", + }, +) + +go_template_instance( + name = "tasks_inode_refs", + out = "tasks_inode_refs.go", + package = "proc", + prefix = "tasksInode", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "tasksInode", + }, +) + go_library( name = "proc", srcs = [ + "fd_dir_inode_refs.go", + "fd_info_dir_inode_refs.go", "filesystem.go", "subtasks.go", + "subtasks_inode_refs.go", "task.go", "task_fds.go", "task_files.go", + "task_inode_refs.go", "task_net.go", "tasks.go", "tasks_files.go", + "tasks_inode_refs.go", "tasks_sys.go", ], visibility = ["//pkg/sentry:internal"], @@ -36,6 +97,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserror", "//pkg/tcpip/header", "//pkg/tcpip/network/ipv4", diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 2463d51cd..03b5941b9 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -110,8 +110,21 @@ func newStaticFile(data string) *staticFile { return &staticFile{StaticData: vfs.StaticData{Data: data}} } +func newStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]*kernfs.Dentry) *kernfs.Dentry { + return kernfs.NewStaticDir(creds, devMajor, devMinor, ino, perm, children, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) +} + // InternalData contains internal data passed in to the procfs mount via // vfs.GetFilesystemOptions.InternalData. type InternalData struct { Cgroups map[string]string } + +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.PROC_SUPER_MAGIC), nil +} diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index 79c2725f3..57f026040 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -31,11 +31,13 @@ import ( // // +stateify savable type subtasksInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid + subtasksInodeRefs locks vfs.FileLocks @@ -57,6 +59,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, // Note: credentials are overridden by taskOwnedInode. subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + subInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: subInode, owner: task} dentry := &kernfs.Dentry{} @@ -65,7 +68,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, return dentry } -// Lookup implements kernfs.inodeDynamicLookup. +// Lookup implements kernfs.inodeDynamicLookup.Lookup. func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { tid, err := strconv.ParseUint(name, 10, 32) if err != nil { @@ -84,7 +87,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, e return subTaskDentry.VFSDentry(), nil } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { @@ -152,10 +155,12 @@ func (fd *subtasksFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) erro return fd.GenericDirectoryFD.SetStat(ctx, opts) } -// Open implements kernfs.Inode. +// Open implements kernfs.Inode.Open. func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &subtasksFD{task: i.task} - if err := fd.Init(&i.OrderedChildren, &i.locks, &opts); err != nil { + if err := fd.Init(&i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }); err != nil { return nil, err } if err := fd.VFSFileDescription().Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{}); err != nil { @@ -164,7 +169,7 @@ func (i *subtasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *v return fd.VFSFileDescription(), nil } -// Stat implements kernfs.Inode. +// Stat implements kernfs.Inode.Stat. func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { stat, err := i.InodeAttrs.Stat(ctx, vsfs, opts) if err != nil { @@ -176,7 +181,12 @@ func (i *subtasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs return stat, nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*subtasksInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } + +// DecRef implements kernfs.Inode.DecRef. +func (i *subtasksInode) DecRef(context.Context) { + i.subtasksInodeRefs.DecRef(i.Destroy) +} diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index a5c7aa470..e24c8a031 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -32,11 +32,13 @@ import ( // // +stateify savable type taskInode struct { - kernfs.InodeNotSymlink + implStatFS + kernfs.InodeAttrs kernfs.InodeDirectoryNoNewChildren kernfs.InodeNoDynamicLookup - kernfs.InodeAttrs + kernfs.InodeNotSymlink kernfs.OrderedChildren + taskInodeRefs locks vfs.FileLocks @@ -84,6 +86,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: taskInode, owner: task} dentry := &kernfs.Dentry{} @@ -103,20 +106,27 @@ func (i *taskInode) Valid(ctx context.Context) bool { return i.task.ExitState() != kernel.TaskExitDead } -// Open implements kernfs.Inode. +// Open implements kernfs.Inode.Open. func (i *taskInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*taskInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } +// DecRef implements kernfs.Inode.DecRef. +func (i *taskInode) DecRef(context.Context) { + i.taskInodeRefs.DecRef(i.Destroy) +} + // taskOwnedInode implements kernfs.Inode and overrides inode owner with task // effective user and group. type taskOwnedInode struct { @@ -142,7 +152,10 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. dir := &kernfs.StaticDirectory{} // Note: credentials are overridden by taskOwnedInode. - dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + dir.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) + dir.EnableLeakCheck() inode := &taskOwnedInode{Inode: dir, owner: task} d := &kernfs.Dentry{} @@ -155,7 +168,7 @@ func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux. return d } -// Stat implements kernfs.Inode. +// Stat implements kernfs.Inode.Stat. func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOptions) (linux.Statx, error) { stat, err := i.Inode.Stat(ctx, fs, opts) if err != nil { @@ -173,7 +186,7 @@ func (i *taskOwnedInode) Stat(ctx context.Context, fs *vfs.Filesystem, opts vfs. return stat, nil } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *taskOwnedInode) CheckPermissions(_ context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { mode := i.Mode() uid, gid := i.getOwner(mode) diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index f0d3f7f5e..c492bcfa7 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -63,7 +62,7 @@ type fdDir struct { produceSymlink bool } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { @@ -87,26 +86,33 @@ func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, off Name: strconv.FormatUint(uint64(fd), 10), Type: typ, Ino: i.fs.NextIno(), - NextOff: offset + 1, + NextOff: int64(fd) + 3, } if err := cb.Handle(dirent); err != nil { - return offset, err + // Getdents should iterate correctly despite mutation + // of fds, so we return the next fd to serialize plus + // 2 (which accounts for the "." and ".." tracked by + // kernfs) as the offset. + return int64(fd) + 2, err } - offset++ } - return offset, nil + // We serialized them all. Next offset should be higher than last + // serialized fd. + return int64(fds[len(fds)-1]) + 3, nil } // fdDirInode represents the inode for /proc/[pid]/fd directory. // // +stateify savable type fdDirInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + fdDir + fdDirInodeRefs + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid - fdDir } var _ kernfs.Inode = (*fdDirInode)(nil) @@ -120,6 +126,7 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry { }, } inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.EnableLeakCheck() dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -128,7 +135,7 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) *kernfs.Dentry { return dentry } -// Lookup implements kernfs.inodeDynamicLookup. +// Lookup implements kernfs.inodeDynamicLookup.Lookup. func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { @@ -142,16 +149,18 @@ func (i *fdDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro return taskDentry.VFSDentry(), nil } -// Open implements kernfs.Inode. +// Open implements kernfs.Inode.Open. func (i *fdDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } -// CheckPermissions implements kernfs.Inode. +// CheckPermissions implements kernfs.Inode.CheckPermissions. // // This is to match Linux, which uses a special permission handler to guarantee // that a process can still access /proc/self/fd after it has executed @@ -173,10 +182,16 @@ func (i *fdDirInode) CheckPermissions(ctx context.Context, creds *auth.Credentia return err } +// DecRef implements kernfs.Inode.DecRef. +func (i *fdDirInode) DecRef(context.Context) { + i.fdDirInodeRefs.DecRef(i.Destroy) +} + // fdSymlink is an symlink for the /proc/[pid]/fd/[fd] file. // // +stateify savable type fdSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -199,7 +214,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) *ker return d } -func (s *fdSymlink) Readlink(ctx context.Context) (string, error) { +func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { file, _ := getTaskFD(s.task, s.fd) if file == nil { return "", syserror.ENOENT @@ -225,12 +240,14 @@ func (s *fdSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDen // // +stateify savable type fdInfoDirInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + fdDir + fdInfoDirInodeRefs + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid - fdDir } var _ kernfs.Inode = (*fdInfoDirInode)(nil) @@ -243,6 +260,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry { }, } inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.EnableLeakCheck() dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -251,7 +269,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) *kernfs.Dentry { return dentry } -// Lookup implements kernfs.inodeDynamicLookup. +// Lookup implements kernfs.inodeDynamicLookup.Lookup. func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { fdInt, err := strconv.ParseInt(name, 10, 32) if err != nil { @@ -269,21 +287,27 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, return dentry.VFSDentry(), nil } -// Open implements kernfs.Inode. +// Open implements kernfs.Inode.Open. func (i *fdInfoDirInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } +// DecRef implements kernfs.Inode.DecRef. +func (i *fdInfoDirInode) DecRef(context.Context) { + i.fdInfoDirInodeRefs.DecRef(i.Destroy) +} + // fdInfoData implements vfs.DynamicBytesSource for /proc/[pid]/fdinfo/[fd]. // // +stateify savable type fdInfoData struct { kernfs.DynamicBytesFile - refs.AtomicRefCount task *kernel.Task fd int32 diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 830b78949..8f7e9b801 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -543,7 +543,7 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { var vss, rss, data uint64 s.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { - fds = fdTable.Size() + fds = fdTable.CurrentMaxFDs() } if mm := t.MemoryManager(); mm != nil { vss = mm.VirtualMemorySize() @@ -648,6 +648,7 @@ func (o *oomScoreAdj) Write(ctx context.Context, src usermem.IOSequence, offset // // +stateify savable type exeSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -666,8 +667,8 @@ func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) *kernfs.Dentr return d } -// Readlink implements kernfs.Inode. -func (s *exeSymlink) Readlink(ctx context.Context) (string, error) { +// Readlink implements kernfs.Inode.Readlink. +func (s *exeSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { if !kernel.ContextCanTrace(ctx, s.task, false) { return "", syserror.EACCES } @@ -806,15 +807,15 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri return d } -// Readlink implements Inode. -func (s *namespaceSymlink) Readlink(ctx context.Context) (string, error) { +// Readlink implements kernfs.Inode.Readlink. +func (s *namespaceSymlink) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { if err := checkTaskState(s.task); err != nil { return "", err } - return s.StaticSymlink.Readlink(ctx) + return s.StaticSymlink.Readlink(ctx, mnt) } -// Getlink implements Inode.Getlink. +// Getlink implements kernfs.Inode.Getlink. func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { if err := checkTaskState(s.task); err != nil { return vfs.VirtualDentry{}, "", err @@ -832,6 +833,7 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir // namespaceInode is a synthetic inode created to represent a namespace in // /proc/[pid]/ns/*. type namespaceInode struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeNotDirectory @@ -850,7 +852,7 @@ func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32 i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } -// Open implements Inode.Open. +// Open implements kernfs.Inode.Open. func (i *namespaceInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { fd := &namespaceFD{inode: i} i.IncRef() @@ -873,20 +875,20 @@ type namespaceFD struct { var _ vfs.FileDescriptionImpl = (*namespaceFD)(nil) -// Stat implements FileDescriptionImpl. +// Stat implements vfs.FileDescriptionImpl.Stat. func (fd *namespaceFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() return fd.inode.Stat(ctx, vfs, opts) } -// SetStat implements FileDescriptionImpl. +// SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { vfs := fd.vfsfd.VirtualDentry().Mount().Filesystem() creds := auth.CredentialsFromContext(ctx) return fd.inode.SetStat(ctx, vfs, creds, opts) } -// Release implements FileDescriptionImpl. +// Release implements vfs.FileDescriptionImpl.Release. func (fd *namespaceFD) Release(ctx context.Context) { fd.inode.DecRef(ctx) } diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index a4c884bf9..1607eac19 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -262,7 +262,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { // For now, we always redact this pointer. fmt.Fprintf(buf, "%#016p: %08X %08X %08X %04X %02X %8d", (*unix.SocketOperations)(nil), // Num, pointer to kernel socket struct. - s.Refs()-1, // RefCount, don't count our own ref. + s.ReadRefs()-1, // RefCount, don't count our own ref. 0, // Protocol, always 0 for UDS. sockFlags, // Flags. sops.Endpoint().Type(), // Type. @@ -430,7 +430,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, // Field: refcount. Don't count the ref we obtain while deferencing // the weakref to this socket. - fmt.Fprintf(buf, "%d ", s.Refs()-1) + fmt.Fprintf(buf, "%d ", s.ReadRefs()-1) // Field: Socket struct address. Redacted due to the same reason as // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. @@ -589,7 +589,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Field: ref; reference count on the socket inode. Don't count the ref // we obtain while deferencing the weakref to this socket. - fmt.Fprintf(buf, "%d ", s.Refs()-1) + fmt.Fprintf(buf, "%d ", s.ReadRefs()-1) // Field: Socket struct address. Redacted due to the same reason as // the 'Num' field in /proc/net/unix, see netUnix.ReadSeqFileData. @@ -660,7 +660,7 @@ func sprintSlice(s []uint64) string { return r[1 : len(r)-1] // Remove "[]" introduced by fmt of slice. } -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *netSnmpData) Generate(ctx context.Context, buf *bytes.Buffer) error { types := []interface{}{ &inet.StatSNMPIP{}, @@ -709,7 +709,7 @@ type netRouteData struct { var _ dynamicInode = (*netRouteData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. // See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. func (d *netRouteData) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "%-127s\n", "Iface\tDestination\tGateway\tFlags\tRefCnt\tUse\tMetric\tMask\tMTU\tWindow\tIRTT") @@ -773,7 +773,7 @@ type netStatData struct { var _ dynamicInode = (*netStatData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. // See Linux's net/ipv4/fib_trie.c:fib_route_seq_show. func (d *netStatData) Generate(ctx context.Context, buf *bytes.Buffer) error { buf.WriteString("TcpExt: SyncookiesSent SyncookiesRecv SyncookiesFailed " + diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 6d2b90a8b..6d60acc30 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -37,11 +37,13 @@ const ( // // +stateify savable type tasksInode struct { - kernfs.InodeNotSymlink - kernfs.InodeDirectoryNoNewChildren + implStatFS + kernfs.AlwaysValid kernfs.InodeAttrs + kernfs.InodeDirectoryNoNewChildren + kernfs.InodeNotSymlink kernfs.OrderedChildren - kernfs.AlwaysValid + tasksInodeRefs locks vfs.FileLocks @@ -84,6 +86,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace cgroupControllers: cgroupControllers, } inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.EnableLeakCheck() dentry := &kernfs.Dentry{} dentry.Init(inode) @@ -95,7 +98,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace return inode, dentry } -// Lookup implements kernfs.inodeDynamicLookup. +// Lookup implements kernfs.inodeDynamicLookup.Lookup. func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, error) { // Try to lookup a corresponding task. tid, err := strconv.ParseUint(name, 10, 64) @@ -119,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (*vfs.Dentry, erro return taskDentry.VFSDentry(), nil } -// IterDirents implements kernfs.inodeDynamicLookup. +// IterDirents implements kernfs.inodeDynamicLookup.IterDirents. func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 @@ -197,9 +200,11 @@ func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback return maxTaskID, nil } -// Open implements kernfs.Inode. +// Open implements kernfs.Inode.Open. func (i *tasksInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &i.OrderedChildren, &i.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndZero, + }) if err != nil { return nil, err } @@ -224,6 +229,11 @@ func (i *tasksInode) Stat(ctx context.Context, vsfs *vfs.Filesystem, opts vfs.St return stat, nil } +// DecRef implements kernfs.Inode.DecRef. +func (i *tasksInode) DecRef(context.Context) { + i.tasksInodeRefs.DecRef(i.Destroy) +} + // staticFileSetStat implements a special static file that allows inode // attributes to be set. This is to support /proc files that are readonly, but // allow attributes to be set. diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 7d8983aa5..459a8e52e 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -32,6 +32,7 @@ import ( ) type selfSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -50,7 +51,7 @@ func (fs *filesystem) newSelfSymlink(creds *auth.Credentials, ino uint64, pidns return d } -func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { +func (s *selfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? @@ -63,17 +64,18 @@ func (s *selfSymlink) Readlink(ctx context.Context) (string, error) { return strconv.FormatUint(uint64(tgid), 10), nil } -func (s *selfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { - target, err := s.Readlink(ctx) +func (s *selfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx, mnt) return vfs.VirtualDentry{}, target, err } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*selfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } type threadSelfSymlink struct { + implStatFS kernfs.InodeAttrs kernfs.InodeNoopRefCount kernfs.InodeSymlink @@ -92,7 +94,7 @@ func (fs *filesystem) newThreadSelfSymlink(creds *auth.Credentials, ino uint64, return d } -func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { +func (s *threadSelfSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error) { t := kernel.TaskFromContext(ctx) if t == nil { // Who is reading this link? @@ -106,12 +108,12 @@ func (s *threadSelfSymlink) Readlink(ctx context.Context) (string, error) { return fmt.Sprintf("%d/task/%d", tgid, tid), nil } -func (s *threadSelfSymlink) Getlink(ctx context.Context, _ *vfs.Mount) (vfs.VirtualDentry, string, error) { - target, err := s.Readlink(ctx) +func (s *threadSelfSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.VirtualDentry, string, error) { + target, err := s.Readlink(ctx, mnt) return vfs.VirtualDentry{}, target, err } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*threadSelfSymlink) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } @@ -123,7 +125,7 @@ type dynamicBytesFileSetAttr struct { kernfs.DynamicBytesFile } -// SetStat implements Inode.SetStat. +// SetStat implements kernfs.Inode.SetStat. func (d *dynamicBytesFileSetAttr) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { return d.DynamicBytesFile.InodeAttrs.SetStat(ctx, fs, creds, opts) } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 6768aa880..a3ffbb15e 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -25,21 +25,29 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" ) +type tcpMemDir int + +const ( + tcpRMem tcpMemDir = iota + tcpWMem +) + // newSysDir returns the dentry corresponding to /proc/sys directory. func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) *kernfs.Dentry { - return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ - "kernel": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "kernel": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ "hostname": fs.newDentry(root, fs.NextIno(), 0444, &hostnameData{}), "shmall": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMALL)), "shmmax": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMAX)), "shmmni": fs.newDentry(root, fs.NextIno(), 0444, shmData(linux.SHMMNI)), }), - "vm": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "vm": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ "mmap_min_addr": fs.newDentry(root, fs.NextIno(), 0444, &mmapMinAddrData{k: k}), "overcommit_memory": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("0\n")), }), @@ -55,9 +63,12 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ - "ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "ipv4": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpRMem}), "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newDentry(root, fs.NextIno(), 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newDentry(root, fs.NextIno(), 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -100,7 +111,7 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke "tcp_syn_retries": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("3")), "tcp_timestamps": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("1")), }), - "core": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ + "core": newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ "default_qdisc": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("pfifo_fast")), "message_burst": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("10")), "message_cost": fs.newDentry(root, fs.NextIno(), 0444, newStaticFile("5")), @@ -114,7 +125,7 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke } } - return kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents) + return newStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for @@ -165,7 +176,7 @@ type tcpSackData struct { var _ vfs.WritableDynamicBytesSource = (*tcpSackData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error { if d.enabled == nil { sack, err := d.stack.TCPSACKEnabled() @@ -182,10 +193,11 @@ func (d *tcpSackData) Generate(ctx context.Context, buf *bytes.Buffer) error { // Tough luck. val = "1\n" } - buf.WriteString(val) - return nil + _, err := buf.WriteString(val) + return err } +// Write implements vfs.WritableDynamicBytesSource.Write. func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. @@ -201,7 +213,7 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset var v int32 n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) if err != nil { - return n, err + return 0, err } if d.enabled == nil { d.enabled = new(bool) @@ -222,17 +234,18 @@ type tcpRecoveryData struct { var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil) -// Generate implements vfs.DynamicBytesSource. +// Generate implements vfs.DynamicBytesSource.Generate. func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error { recovery, err := d.stack.TCPRecovery() if err != nil { return err } - buf.WriteString(fmt.Sprintf("%d\n", recovery)) - return nil + _, err = buf.WriteString(fmt.Sprintf("%d\n", recovery)) + return err } +// Write implements vfs.WritableDynamicBytesSource.Write. func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { if offset != 0 { // No need to handle partial writes thus far. @@ -256,6 +269,94 @@ func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, off return n, nil } +// tcpMemData implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/tcp_rmem and /proc/sys/net/ipv4/tcp_wmem. +// +// +stateify savable +type tcpMemData struct { + kernfs.DynamicBytesFile + + dir tcpMemDir + stack inet.Stack `state:"wait"` + + // mu protects against concurrent reads/writes to FDs based on the dentry + // backing this byte source. + mu sync.Mutex `state:"nosave"` +} + +var _ vfs.WritableDynamicBytesSource = (*tcpMemData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (d *tcpMemData) Generate(ctx context.Context, buf *bytes.Buffer) error { + d.mu.Lock() + defer d.mu.Unlock() + + size, err := d.readSizeLocked() + if err != nil { + return err + } + _, err = buf.WriteString(fmt.Sprintf("%d\t%d\t%d\n", size.Min, size.Default, size.Max)) + return err +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (d *tcpMemData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + d.mu.Lock() + defer d.mu.Unlock() + + // Limit the amount of memory allocated. + src = src.TakeFirst(usermem.PageSize - 1) + size, err := d.readSizeLocked() + if err != nil { + return 0, err + } + buf := []int32{int32(size.Min), int32(size.Default), int32(size.Max)} + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, buf, src.Opts) + if err != nil { + return 0, err + } + newSize := inet.TCPBufferSize{ + Min: int(buf[0]), + Default: int(buf[1]), + Max: int(buf[2]), + } + if err := d.writeSizeLocked(newSize); err != nil { + return 0, err + } + return n, nil +} + +// Precondition: d.mu must be locked. +func (d *tcpMemData) readSizeLocked() (inet.TCPBufferSize, error) { + switch d.dir { + case tcpRMem: + return d.stack.TCPReceiveBufferSize() + case tcpWMem: + return d.stack.TCPSendBufferSize() + default: + panic(fmt.Sprintf("unknown tcpMemFile type: %v", d.dir)) + } +} + +// Precondition: d.mu must be locked. +func (d *tcpMemData) writeSizeLocked(size inet.TCPBufferSize) error { + switch d.dir { + case tcpRMem: + return d.stack.SetTCPReceiveBufferSize(size) + case tcpWMem: + return d.stack.SetTCPSendBufferSize(size) + default: + panic(fmt.Sprintf("unknown tcpMemFile type: %v", d.dir)) + } +} + // ipForwarding implements vfs.WritableDynamicBytesSource for // /proc/sys/net/ipv4/ip_forwarding. // diff --git a/pkg/sentry/fsimpl/proc/tasks_sys_test.go b/pkg/sentry/fsimpl/proc/tasks_sys_test.go index 1abf56da2..6cee22823 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys_test.go @@ -137,12 +137,12 @@ func TestConfigureIPForwarding(t *testing.T) { // Write the values. src := usermem.BytesIOSequence([]byte(c.str)) if n, err := file.Write(ctx, src, 0); n != int64(len(c.str)) || err != nil { - t.Errorf("file.Write(ctx, nil, %v, 0) = (%d, %v); wanted (%d, nil)", c.str, n, err, len(c.str)) + t.Errorf("file.Write(ctx, nil, %q, 0) = (%d, %v); want (%d, nil)", c.str, n, err, len(c.str)) } // Read the values from the stack and check them. - if s.IPForwarding != c.final { - t.Errorf("s.IPForwarding = %v; wanted %v", s.IPForwarding, c.final) + if got, want := s.IPForwarding, c.final; got != want { + t.Errorf("s.IPForwarding incorrect; got: %v, want: %v", got, want) } }) } diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 3c9297dee..f693f9060 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -104,7 +104,7 @@ func setup(t *testing.T) *testutil.System { AllowUserMount: true, }) - mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{}) + mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.MountOptions{}) if err != nil { t.Fatalf("NewMountNamespace(): %v", err) } @@ -132,7 +132,7 @@ func setup(t *testing.T) *testutil.System { }, }, } - if err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil { + if _, err := k.VFS().MountAt(ctx, creds, "", pop, Name, mntOpts); err != nil { t.Fatalf("MountAt(/proc): %v", err) } return testutil.NewSystem(ctx, t, k.VFS(), mntns) diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 6297e1df4..3c02af8c9 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -26,7 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// SignalFileDescription implements FileDescriptionImpl for signal fds. +// SignalFileDescription implements vfs.FileDescriptionImpl for signal fds. type SignalFileDescription struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -83,7 +83,7 @@ func (sfd *SignalFileDescription) SetMask(mask linux.SignalSet) { sfd.mask = mask } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequence, _ vfs.ReadOptions) (int64, error) { // Attempt to dequeue relevant signals. info, err := sfd.target.Sigtimedwait(sfd.Mask(), 0) @@ -132,5 +132,5 @@ func (sfd *SignalFileDescription) EventUnregister(entry *waiter.Entry) { sfd.target.SignalUnregister(entry) } -// Release implements FileDescriptionImpl.Release() +// Release implements vfs.FileDescriptionImpl.Release. func (sfd *SignalFileDescription) Release(context.Context) {} diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index c61818ff6..80b41aa9e 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -30,12 +30,12 @@ import ( // filesystemType implements vfs.FilesystemType. type filesystemType struct{} -// GetFilesystem implements FilesystemType.GetFilesystem. +// GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fsType filesystemType) GetFilesystem(_ context.Context, vfsObj *vfs.VirtualFilesystem, _ *auth.Credentials, _ string, _ vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { panic("sockfs.filesystemType.GetFilesystem should never be called") } -// Name implements FilesystemType.Name. +// Name implements vfs.FilesystemType.Name. // // Note that registering sockfs is unnecessary, except for the fact that it // will not show up under /proc/filesystems as a result. This is a very minor @@ -81,10 +81,10 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe // inode implements kernfs.Inode. type inode struct { - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink kernfs.InodeAttrs kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink } // Open implements kernfs.Inode.Open. @@ -92,6 +92,11 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentr return nil, syserror.ENXIO } +// StatFS implements kernfs.Inode.StatFS. +func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.SOCKFS_MAGIC), nil +} + // NewDentry constructs and returns a sockfs dentry. // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 1b548ccd4..906cd52cb 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -1,21 +1,41 @@ load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) +go_template_instance( + name = "dir_refs", + out = "dir_refs.go", + package = "sys", + prefix = "dir", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "dir", + }, +) + go_library( name = "sys", srcs = [ + "dir_refs.go", + "kcov.go", "sys.go", ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/coverage", + "//pkg/log", + "//pkg/refs", + "//pkg/sentry/arch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", "//pkg/sentry/vfs", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go new file mode 100644 index 000000000..73f3d3309 --- /dev/null +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -0,0 +1,117 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sys + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) *kernfs.Dentry { + k := &kcovInode{} + k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) + d := &kernfs.Dentry{} + d.Init(k) + return d +} + +// kcovInode implements kernfs.Inode. +type kcovInode struct { + kernfs.InodeAttrs + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + implStatFS +} + +func (i *kcovInode) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + k := kernel.KernelFromContext(ctx) + if k == nil { + panic("KernelFromContext returned nil") + } + fd := &kcovFD{ + inode: i, + kcov: k.NewKcov(), + } + + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), vfsd, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + }); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +type kcovFD struct { + vfs.FileDescriptionDefaultImpl + vfs.NoLockFD + + vfsfd vfs.FileDescription + inode *kcovInode + kcov *kernel.Kcov +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *kcovFD) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + cmd := uint32(args[1].Int()) + arg := args[2].Uint64() + switch uint32(cmd) { + case linux.KCOV_INIT_TRACE: + return 0, fd.kcov.InitTrace(arg) + case linux.KCOV_ENABLE: + return 0, fd.kcov.EnableTrace(ctx, uint8(arg)) + case linux.KCOV_DISABLE: + if arg != 0 { + // This arg is unused; it should be 0. + return 0, syserror.EINVAL + } + return 0, fd.kcov.DisableTrace(ctx) + default: + return 0, syserror.ENOTTY + } +} + +// ConfigureMmap implements vfs.FileDescriptionImpl.ConfigureMmap. +func (fd *kcovFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + return fd.kcov.ConfigureMMap(ctx, opts) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *kcovFD) Release(ctx context.Context) { + // kcov instances have reference counts in Linux, but this seems sufficient + // for our purposes. + fd.kcov.Reset() +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *kcovFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + creds := auth.CredentialsFromContext(ctx) + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.SetStat(ctx, fs, creds, opts) +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *kcovFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + return fd.inode.Stat(ctx, fd.vfsfd.Mount().Filesystem(), opts) +} diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 0401726b6..39952d2d0 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/coverage" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -73,7 +74,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt }), "firmware": fs.newDir(creds, defaultSysDirMode, nil), "fs": fs.newDir(creds, defaultSysDirMode, nil), - "kernel": fs.newDir(creds, defaultSysDirMode, nil), + "kernel": kernelDir(ctx, fs, creds), "module": fs.newDir(creds, defaultSysDirMode, nil), "power": fs.newDir(creds, defaultSysDirMode, nil), }) @@ -94,6 +95,21 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernf return fs.newDir(creds, defaultSysDirMode, children) } +func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) *kernfs.Dentry { + // If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs + // should be mounted at debug/, but for our purposes, it is sufficient to + // keep it in sys. + var children map[string]*kernfs.Dentry + if coverage.KcovAvailable() { + children = map[string]*kernfs.Dentry{ + "debug": fs.newDir(creds, linux.FileMode(0700), map[string]*kernfs.Dentry{ + "kcov": fs.newKcovFile(ctx, creds), + }), + } + } + return fs.newDir(creds, defaultSysDirMode, children) +} + // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) @@ -102,6 +118,7 @@ func (fs *filesystem) Release(ctx context.Context) { // dir implements kernfs.Inode. type dir struct { + dirRefs kernfs.InodeAttrs kernfs.InodeNoDynamicLookup kernfs.InodeNotSymlink @@ -117,6 +134,7 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte d := &dir{} d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) + d.EnableLeakCheck() d.dentry.Init(d) d.IncLinks(d.OrderedChildren.Populate(&d.dentry, contents)) @@ -124,23 +142,37 @@ func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, conte return &d.dentry } -// SetStat implements Inode.SetStat not allowing inode attributes to be changed. +// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed. func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { return syserror.EPERM } // Open implements kernfs.Inode.Open. func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts) + fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), vfsd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{ + SeekEnd: kernfs.SeekEndStaticEntries, + }) if err != nil { return nil, err } return fd.VFSFileDescription(), nil } +// DecRef implements kernfs.Inode.DecRef. +func (d *dir) DecRef(context.Context) { + d.dirRefs.DecRef(d.Destroy) +} + +// StatFS implements kernfs.Inode.StatFS. +func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.SYSFS_MAGIC), nil +} + // cpuFile implements kernfs.Inode. type cpuFile struct { + implStatFS kernfs.DynamicBytesFile + maxCores uint } @@ -157,3 +189,10 @@ func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode li d.Init(c) return d } + +type implStatFS struct{} + +// StatFS implements kernfs.Inode.StatFS. +func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) { + return vfs.GenericStatFS(linux.SYSFS_MAGIC), nil +} diff --git a/pkg/sentry/fsimpl/sys/sys_test.go b/pkg/sentry/fsimpl/sys/sys_test.go index 9fd38b295..0a0d914cc 100644 --- a/pkg/sentry/fsimpl/sys/sys_test.go +++ b/pkg/sentry/fsimpl/sys/sys_test.go @@ -38,7 +38,7 @@ func newTestSystem(t *testing.T) *testutil.System { AllowUserMount: true, }) - mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.GetFilesystemOptions{}) + mns, err := k.VFS().NewMountNamespace(ctx, creds, "", sys.Name, &vfs.MountOptions{}) if err != nil { t.Fatalf("Failed to create new mount namespace: %v", err) } diff --git a/pkg/sentry/fsimpl/timerfd/timerfd.go b/pkg/sentry/fsimpl/timerfd/timerfd.go index 86beaa0a8..ac8a4e3bb 100644 --- a/pkg/sentry/fsimpl/timerfd/timerfd.go +++ b/pkg/sentry/fsimpl/timerfd/timerfd.go @@ -26,7 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// TimerFileDescription implements FileDescriptionImpl for timer fds. It also +// TimerFileDescription implements vfs.FileDescriptionImpl for timer fds. It also // implements ktime.TimerListener. type TimerFileDescription struct { vfsfd vfs.FileDescription @@ -62,7 +62,7 @@ func New(ctx context.Context, vfsObj *vfs.VirtualFilesystem, clock ktime.Clock, return &tfd.vfsfd, nil } -// Read implements FileDescriptionImpl.Read. +// Read implements vfs.FileDescriptionImpl.Read. func (tfd *TimerFileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { const sizeofUint64 = 8 if dst.NumBytes() < sizeofUint64 { @@ -128,7 +128,7 @@ func (tfd *TimerFileDescription) ResumeTimer() { tfd.timer.Resume() } -// Release implements FileDescriptionImpl.Release() +// Release implements vfs.FileDescriptionImpl.Release. func (tfd *TimerFileDescription) Release(context.Context) { tfd.timer.Destroy() } diff --git a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go index d263147c2..5209a17af 100644 --- a/pkg/sentry/fsimpl/tmpfs/benchmark_test.go +++ b/pkg/sentry/fsimpl/tmpfs/benchmark_test.go @@ -182,7 +182,7 @@ func BenchmarkVFS2TmpfsStat(b *testing.B) { vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } @@ -376,7 +376,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { vfsObj.MustRegisterFilesystemType("tmpfs", tmpfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { b.Fatalf("failed to create tmpfs root mount: %v", err) } @@ -405,7 +405,7 @@ func BenchmarkVFS2TmpfsMountStat(b *testing.B) { } defer mountPoint.DecRef(ctx) // Create and mount the submount. - if err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { + if _, err := vfsObj.MountAt(ctx, creds, "", &pop, "tmpfs", &vfs.MountOptions{}); err != nil { b.Fatalf("failed to mount tmpfs submount: %v", err) } filePathBuilder.WriteString(mountPointName) diff --git a/pkg/sentry/fsimpl/tmpfs/directory.go b/pkg/sentry/fsimpl/tmpfs/directory.go index 78b4fc5be..070c75e68 100644 --- a/pkg/sentry/fsimpl/tmpfs/directory.go +++ b/pkg/sentry/fsimpl/tmpfs/directory.go @@ -57,8 +57,9 @@ func (fs *filesystem) newDirectory(kuid auth.KUID, kgid auth.KGID, mode linux.Fi return dir } -// Preconditions: filesystem.mu must be locked for writing. dir must not -// already contain a child with the given name. +// Preconditions: +// * filesystem.mu must be locked for writing. +// * dir must not already contain a child with the given name. func (dir *directory) insertChildLocked(child *dentry, name string) { child.parent = &dir.dentry child.name = name diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index cb8b2d944..1362c1602 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. @@ -39,7 +38,9 @@ func (fs *filesystem) Sync(ctx context.Context) error { // // stepLocked is loosely analogous to fs/namei.c:walk_component(). // -// Preconditions: filesystem.mu must be locked. !rp.Done(). +// Preconditions: +// * filesystem.mu must be locked. +// * !rp.Done(). func stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*dentry, error) { dir, ok := d.inode.impl.(*directory) if !ok { @@ -97,7 +98,9 @@ afterSymlink: // walkParentDirLocked is loosely analogous to Linux's // fs/namei.c:path_parentat(). // -// Preconditions: filesystem.mu must be locked. !rp.Done(). +// Preconditions: +// * filesystem.mu must be locked. +// * !rp.Done(). func walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry) (*directory, error) { for !rp.Final() { next, err := stepLocked(ctx, rp, d) @@ -139,8 +142,9 @@ func resolveLocked(ctx context.Context, rp *vfs.ResolvingPath) (*dentry, error) // doCreateAt is loosely analogous to a conjunction of Linux's // fs/namei.c:filename_create() and done_path_create(). // -// Preconditions: !rp.Done(). For the final path component in rp, -// !rp.ShouldFollowSymlink(). +// Preconditions: +// * !rp.Done(). +// * For the final path component in rp, !rp.ShouldFollowSymlink(). func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir bool, create func(parentDir *directory, name string) error) error { fs.mu.Lock() defer fs.mu.Unlock() @@ -307,18 +311,28 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf // don't need fs.mu for writing. if opts.Flags&linux.O_CREAT == 0 { fs.mu.RLock() - defer fs.mu.RUnlock() d, err := resolveLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() return nil, err } + d.IncRef() + defer d.DecRef(ctx) + fs.mu.RUnlock() return d.open(ctx, rp, &opts, false /* afterCreate */) } mustCreate := opts.Flags&linux.O_EXCL != 0 start := rp.Start().Impl().(*dentry) fs.mu.Lock() - defer fs.mu.Unlock() + unlocked := false + unlock := func() { + if !unlocked { + fs.mu.Unlock() + unlocked = true + } + } + defer unlock() if rp.Done() { // Reject attempts to open mount root directory with O_CREAT. if rp.MustBeDir() { @@ -327,6 +341,9 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } + start.IncRef() + defer start.DecRef(ctx) + unlock() return start.open(ctx, rp, &opts, false /* afterCreate */) } afterTrailingSymlink: @@ -364,6 +381,7 @@ afterTrailingSymlink: creds := rp.Credentials() child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)) parentDir.insertChildLocked(child, name) + unlock() fd, err := child.open(ctx, rp, &opts, true) if err != nil { return nil, err @@ -392,9 +410,14 @@ afterTrailingSymlink: if rp.MustBeDir() && !child.inode.isDir() { return nil, syserror.ENOTDIR } + child.IncRef() + defer child.DecRef(ctx) + unlock() return child.open(ctx, rp, &opts, false) } +// Preconditions: The caller must hold no locks (since opening pipes may block +// indefinitely). func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions, afterCreate bool) (*vfs.FileDescription, error) { ats := vfs.AccessTypesForOpenFlags(opts) if !afterCreate { @@ -682,16 +705,7 @@ func (fs *filesystem) StatFSAt(ctx context.Context, rp *vfs.ResolvingPath) (linu if _, err := resolveLocked(ctx, rp); err != nil { return linux.Statfs{}, err } - statfs := linux.Statfs{ - Type: linux.TMPFS_MAGIC, - BlockSize: usermem.PageSize, - FragmentSize: usermem.PageSize, - NameLength: linux.NAME_MAX, - // TODO(b/29637826): Allow configuring a tmpfs size and enforce it. - Blocks: 0, - BlocksFree: 0, - } - return statfs, nil + return globalStatfs, nil } // SymlinkAt implements vfs.FilesystemImpl.SymlinkAt. @@ -756,7 +770,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return nil } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { fs.mu.RLock() defer fs.mu.RUnlock() @@ -769,43 +783,46 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath } switch impl := d.inode.impl.(type) { case *socketFile: + if impl.ep == nil { + return nil, syserror.ECONNREFUSED + } return impl.ep, nil default: return nil, syserror.ECONNREFUSED } } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { fs.mu.RLock() defer fs.mu.RUnlock() d, err := resolveLocked(ctx, rp) if err != nil { return nil, err } - return d.inode.listxattr(size) + return d.inode.listXattr(size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { fs.mu.RLock() defer fs.mu.RUnlock() d, err := resolveLocked(ctx, rp) if err != nil { return "", err } - return d.inode.getxattr(rp.Credentials(), &opts) + return d.inode.getXattr(rp.Credentials(), &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { fs.mu.RLock() d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err } - if err := d.inode.setxattr(rp.Credentials(), &opts); err != nil { + if err := d.inode.setXattr(rp.Credentials(), &opts); err != nil { fs.mu.RUnlock() return err } @@ -815,15 +832,15 @@ func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return nil } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { fs.mu.RLock() d, err := resolveLocked(ctx, rp) if err != nil { fs.mu.RUnlock() return err } - if err := d.inode.removexattr(rp.Credentials(), name); err != nil { + if err := d.inode.removeXattr(rp.Credentials(), name); err != nil { fs.mu.RUnlock() return err } @@ -848,8 +865,16 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe } if d.parent == nil { if d.name != "" { - // This must be an anonymous memfd file. + // This file must have been created by + // newUnlinkedRegularFileDescription(). In Linux, + // mm/shmem.c:__shmem_file_setup() => + // fs/file_table.c:alloc_file_pseudo() sets the created + // dentry's dentry_operations to anon_ops, for which d_dname == + // simple_dname. fs/d_path.c:simple_dname() defines the + // dentry's pathname to be its name, prefixed with "/" and + // suffixed with " (deleted)". b.PrependComponent("/" + d.name) + b.AppendString(" (deleted)") return vfs.PrependPathSyntheticError{} } return vfs.PrependPathAtNonMountRootError{} diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index 739350cf0..5b0471ff4 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -28,8 +28,8 @@ type namedPipe struct { } // Preconditions: -// * fs.mu must be locked. -// * rp.Mount().CheckBeginWrite() has been called successfully. +// * fs.mu must be locked. +// * rp.Mount().CheckBeginWrite() has been called successfully. func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode) diff --git a/pkg/sentry/fsimpl/tmpfs/pipe_test.go b/pkg/sentry/fsimpl/tmpfs/pipe_test.go index ec2701d8b..be29a2363 100644 --- a/pkg/sentry/fsimpl/tmpfs/pipe_test.go +++ b/pkg/sentry/fsimpl/tmpfs/pipe_test.go @@ -158,7 +158,7 @@ func setup(t *testing.T) (context.Context, *auth.Credentials, *vfs.VirtualFilesy vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { t.Fatalf("failed to create tmpfs root mount: %v", err) } diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index 0710b65db..b8699d064 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -42,6 +42,10 @@ type regularFile struct { // memFile is a platform.File used to allocate pages to this regularFile. memFile *pgalloc.MemoryFile + // memoryUsageKind is the memory accounting category under which pages backing + // this regularFile's contents are accounted. + memoryUsageKind usage.MemoryKind + // mapsMu protects mappings. mapsMu sync.Mutex `state:"nosave"` @@ -86,14 +90,75 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, - seals: linux.F_SEAL_SEAL, + memFile: fs.memFile, + memoryUsageKind: usage.Tmpfs, + seals: linux.F_SEAL_SEAL, } file.inode.init(file, fs, kuid, kgid, linux.S_IFREG|mode) file.inode.nlink = 1 // from parent directory return &file.inode } +// newUnlinkedRegularFileDescription creates a regular file on the tmpfs +// filesystem represented by mount and returns an FD representing that file. +// The new file is not reachable by path traversal from any other file. +// +// newUnlinkedRegularFileDescription is analogous to Linux's +// mm/shmem.c:__shmem_file_setup(). +// +// Preconditions: mount must be a tmpfs mount. +func newUnlinkedRegularFileDescription(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, name string) (*regularFileFD, error) { + fs, ok := mount.Filesystem().Impl().(*filesystem) + if !ok { + panic("tmpfs.newUnlinkedRegularFileDescription() called with non-tmpfs mount") + } + + inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777) + d := fs.newDentry(inode) + defer d.DecRef(ctx) + d.name = name + + fd := ®ularFileFD{} + fd.Init(&inode.locks) + flags := uint32(linux.O_RDWR) + if err := fd.vfsfd.Init(fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + return fd, nil +} + +// NewZeroFile creates a new regular file and file description as for +// mmap(MAP_SHARED | MAP_ANONYMOUS). The file has the given size and is +// initially (implicitly) filled with zeroes. +// +// Preconditions: mount must be a tmpfs mount. +func NewZeroFile(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, size uint64) (*vfs.FileDescription, error) { + // Compare mm/shmem.c:shmem_zero_setup(). + fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, "dev/zero") + if err != nil { + return nil, err + } + rf := fd.inode().impl.(*regularFile) + rf.memoryUsageKind = usage.Anonymous + rf.size = size + return &fd.vfsfd, err +} + +// NewMemfd creates a new regular file and file description as for +// memfd_create. +// +// Preconditions: mount must be a tmpfs mount. +func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) { + fd, err := newUnlinkedRegularFileDescription(ctx, creds, mount, name) + if err != nil { + return nil, err + } + if allowSeals { + fd.inode().impl.(*regularFile).seals = 0 + } + return &fd.vfsfd, nil +} + // truncate grows or shrinks the file to the given size. It returns true if the // file size was updated. func (rf *regularFile) truncate(newSize uint64) (bool, error) { @@ -226,7 +291,7 @@ func (rf *regularFile) Translate(ctx context.Context, required, optional memmap. optional.End = pgend } - cerr := rf.data.Fill(ctx, required, optional, rf.memFile, usage.Tmpfs, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { + cerr := rf.data.Fill(ctx, required, optional, rf.memFile, rf.memoryUsageKind, func(_ context.Context, dsts safemem.BlockSeq, _ uint64) (uint64, error) { // Newly-allocated pages are zeroed, so we don't need to do anything. return dsts.NumBytes(), nil }) @@ -575,7 +640,7 @@ func (rw *regularFileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, case gap.Ok(): // Allocate memory for the write. gapMR := gap.Range().Intersect(pgMR) - fr, err := rw.file.memFile.Allocate(gapMR.Length(), usage.Tmpfs) + fr, err := rw.file.memFile.Allocate(gapMR.Length(), rw.file.memoryUsageKind) if err != nil { retErr = err goto exitLoop diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index de2af6d01..4658e1533 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -201,6 +201,25 @@ func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } +// immutable +var globalStatfs = linux.Statfs{ + Type: linux.TMPFS_MAGIC, + BlockSize: usermem.PageSize, + FragmentSize: usermem.PageSize, + NameLength: linux.NAME_MAX, + + // tmpfs currently does not support configurable size limits. In Linux, + // such a tmpfs mount will return f_blocks == f_bfree == f_bavail == 0 from + // statfs(2). However, many applications treat this as having a size limit + // of 0. To work around this, claim to have a very large but non-zero size, + // chosen to ensure that BlockSize * Blocks does not overflow int64 (which + // applications may also handle incorrectly). + // TODO(b/29637826): allow configuring a tmpfs size and enforce it. + Blocks: math.MaxInt64 / usermem.PageSize, + BlocksFree: math.MaxInt64 / usermem.PageSize, + BlocksAvailable: math.MaxInt64 / usermem.PageSize, +} + // dentry implements vfs.DentryImpl. type dentry struct { vfsd vfs.Dentry @@ -340,8 +359,10 @@ func (i *inode) init(impl interface{}, fs *filesystem, kuid auth.KUID, kgid auth // incLinksLocked increments i's link count. // -// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. -// i.nlink < maxLinks. +// Preconditions: +// * filesystem.mu must be locked for writing. +// * i.nlink != 0. +// * i.nlink < maxLinks. func (i *inode) incLinksLocked() { if i.nlink == 0 { panic("tmpfs.inode.incLinksLocked() called with no existing links") @@ -355,7 +376,9 @@ func (i *inode) incLinksLocked() { // decLinksLocked decrements i's link count. If the link count reaches 0, we // remove a reference on i as well. // -// Preconditions: filesystem.mu must be locked for writing. i.nlink != 0. +// Preconditions: +// * filesystem.mu must be locked for writing. +// * i.nlink != 0. func (i *inode) decLinksLocked(ctx context.Context) { if i.nlink == 0 { panic("tmpfs.inode.decLinksLocked() called with no existing links") @@ -594,62 +617,53 @@ func (i *inode) touchCMtime() { i.mu.Unlock() } -// Preconditions: The caller has called vfs.Mount.CheckBeginWrite() and holds -// inode.mu. +// Preconditions: +// * The caller has called vfs.Mount.CheckBeginWrite(). +// * inode.mu must be locked. func (i *inode) touchCMtimeLocked() { now := i.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&i.mtime, now) atomic.StoreInt64(&i.ctime, now) } -func (i *inode) listxattr(size uint64) ([]string, error) { - return i.xattrs.Listxattr(size) +func (i *inode) listXattr(size uint64) ([]string, error) { + return i.xattrs.ListXattr(size) } -func (i *inode) getxattr(creds *auth.Credentials, opts *vfs.GetxattrOptions) (string, error) { - if err := i.checkPermissions(creds, vfs.MayRead); err != nil { +func (i *inode) getXattr(creds *auth.Credentials, opts *vfs.GetXattrOptions) (string, error) { + if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayRead); err != nil { return "", err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return "", syserror.EOPNOTSUPP - } - if !i.userXattrSupported() { - return "", syserror.ENODATA - } - return i.xattrs.Getxattr(opts) + return i.xattrs.GetXattr(opts) } -func (i *inode) setxattr(creds *auth.Credentials, opts *vfs.SetxattrOptions) error { - if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { +func (i *inode) setXattr(creds *auth.Credentials, opts *vfs.SetXattrOptions) error { + if err := i.checkXattrPermissions(creds, opts.Name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(opts.Name, linux.XATTR_USER_PREFIX) { - return syserror.EOPNOTSUPP - } - if !i.userXattrSupported() { - return syserror.EPERM - } - return i.xattrs.Setxattr(opts) + return i.xattrs.SetXattr(opts) } -func (i *inode) removexattr(creds *auth.Credentials, name string) error { - if err := i.checkPermissions(creds, vfs.MayWrite); err != nil { +func (i *inode) removeXattr(creds *auth.Credentials, name string) error { + if err := i.checkXattrPermissions(creds, name, vfs.MayWrite); err != nil { return err } - if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) { + return i.xattrs.RemoveXattr(name) +} + +func (i *inode) checkXattrPermissions(creds *auth.Credentials, name string, ats vfs.AccessTypes) error { + // We currently only support extended attributes in the user.* and + // trusted.* namespaces. See b/148380782. + if !strings.HasPrefix(name, linux.XATTR_USER_PREFIX) && !strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX) { return syserror.EOPNOTSUPP } - if !i.userXattrSupported() { - return syserror.EPERM + mode := linux.FileMode(atomic.LoadUint32(&i.mode)) + kuid := auth.KUID(atomic.LoadUint32(&i.uid)) + kgid := auth.KGID(atomic.LoadUint32(&i.gid)) + if err := vfs.GenericCheckPermissions(creds, ats, mode, kuid, kgid); err != nil { + return err } - return i.xattrs.Removexattr(name) -} - -// Extended attributes in the user.* namespace are only supported for regular -// files and directories. -func (i *inode) userXattrSupported() bool { - filetype := linux.S_IFMT & atomic.LoadUint32(&i.mode) - return filetype == linux.S_IFREG || filetype == linux.S_IFDIR + return vfs.CheckXattrPermissions(creds, ats, mode, kuid, name) } // fileDescription is embedded by tmpfs implementations of @@ -693,20 +707,25 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return nil } -// Listxattr implements vfs.FileDescriptionImpl.Listxattr. -func (fd *fileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { - return fd.inode().listxattr(size) +// StatFS implements vfs.FileDescriptionImpl.StatFS. +func (fd *fileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { + return globalStatfs, nil +} + +// ListXattr implements vfs.FileDescriptionImpl.ListXattr. +func (fd *fileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { + return fd.inode().listXattr(size) } -// Getxattr implements vfs.FileDescriptionImpl.Getxattr. -func (fd *fileDescription) Getxattr(ctx context.Context, opts vfs.GetxattrOptions) (string, error) { - return fd.inode().getxattr(auth.CredentialsFromContext(ctx), &opts) +// GetXattr implements vfs.FileDescriptionImpl.GetXattr. +func (fd *fileDescription) GetXattr(ctx context.Context, opts vfs.GetXattrOptions) (string, error) { + return fd.inode().getXattr(auth.CredentialsFromContext(ctx), &opts) } -// Setxattr implements vfs.FileDescriptionImpl.Setxattr. -func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOptions) error { +// SetXattr implements vfs.FileDescriptionImpl.SetXattr. +func (fd *fileDescription) SetXattr(ctx context.Context, opts vfs.SetXattrOptions) error { d := fd.dentry() - if err := d.inode.setxattr(auth.CredentialsFromContext(ctx), &opts); err != nil { + if err := d.inode.setXattr(auth.CredentialsFromContext(ctx), &opts); err != nil { return err } @@ -715,10 +734,10 @@ func (fd *fileDescription) Setxattr(ctx context.Context, opts vfs.SetxattrOption return nil } -// Removexattr implements vfs.FileDescriptionImpl.Removexattr. -func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { +// RemoveXattr implements vfs.FileDescriptionImpl.RemoveXattr. +func (fd *fileDescription) RemoveXattr(ctx context.Context, name string) error { d := fd.dentry() - if err := d.inode.removexattr(auth.CredentialsFromContext(ctx), name); err != nil { + if err := d.inode.removeXattr(auth.CredentialsFromContext(ctx), name); err != nil { return err } @@ -727,37 +746,6 @@ func (fd *fileDescription) Removexattr(ctx context.Context, name string) error { return nil } -// NewMemfd creates a new tmpfs regular file and file description that can back -// an anonymous fd created by memfd_create. -func NewMemfd(ctx context.Context, creds *auth.Credentials, mount *vfs.Mount, allowSeals bool, name string) (*vfs.FileDescription, error) { - fs, ok := mount.Filesystem().Impl().(*filesystem) - if !ok { - panic("NewMemfd() called with non-tmpfs mount") - } - - // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd inodes are set up with - // S_IRWXUGO. - inode := fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, 0777) - rf := inode.impl.(*regularFile) - if allowSeals { - rf.seals = 0 - } - - d := fs.newDentry(inode) - defer d.DecRef(ctx) - d.name = name - - // Per Linux, mm/shmem.c:__shmem_file_setup(), memfd files are set up with - // FMODE_READ | FMODE_WRITE. - var fd regularFileFD - fd.Init(&inode.locks) - flags := uint32(linux.O_RDWR) - if err := fd.vfsfd.Init(&fd, flags, mount, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { - return nil, err - } - return &fd.vfsfd, nil -} - // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go index 6f3e3ae6f..99c8e3c0f 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs_test.go @@ -41,7 +41,7 @@ func newTmpfsRoot(ctx context.Context) (*vfs.VirtualFilesystem, vfs.VirtualDentr vfsObj.MustRegisterFilesystemType("tmpfs", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserMount: true, }) - mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.GetFilesystemOptions{}) + mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "tmpfs", &vfs.MountOptions{}) if err != nil { return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("failed to create tmpfs root mount: %v", err) } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 28d2a4bcb..bc8e38431 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -13,11 +13,16 @@ go_library( "//pkg/abi/linux", "//pkg/context", "//pkg/fspath", + "//pkg/marshal/primitive", + "//pkg/merkletree", + "//pkg/sentry/arch", "//pkg/sentry/fs/lock", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", "//pkg/syserror", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 78c6074bd..26b117ca4 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -15,9 +15,16 @@ package verity import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/merkletree" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -91,10 +98,350 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de putDentrySlice(*ds) } -// resolveLocked resolves rp to an existing file. -func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { - // TODO(b/159261227): Implement resolveLocked. - return nil, nil +// stepLocked resolves rp.Component() to an existing file, starting from the +// given directory. +// +// Dentries which may have a reference count of zero, and which therefore +// should be dropped once traversal is complete, are appended to ds. +// +// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// !rp.Done(). +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { + if !d.isDir() { + return nil, syserror.ENOTDIR + } + + if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + +afterSymlink: + name := rp.Component() + if name == "." { + rp.Advance() + return d, nil + } + if name == ".." { + if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { + return nil, err + } else if isRoot || d.parent == nil { + rp.Advance() + return d, nil + } + if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { + return nil, err + } + rp.Advance() + return d.parent, nil + } + child, err := fs.getChildLocked(ctx, d, name, ds) + if err != nil { + return nil, err + } + if err := rp.CheckMount(ctx, &child.vfsd); err != nil { + return nil, err + } + if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { + target, err := child.readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + goto afterSymlink // don't check the current directory again + } + rp.Advance() + return child, nil +} + +// verifyChild verifies the root hash of child against the already verified +// root hash of the parent to ensure the child is expected. verifyChild +// triggers a sentry panic if unexpected modifications to the file system are +// detected. In noCrashOnVerificationFailure mode it returns a syserror +// instead. +// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// TODO(b/166474175): Investigate all possible errors returned in this +// function, and make sure we differentiate all errors that indicate unexpected +// modifications to the file system from the ones that are not harmful. +func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { + vfsObj := fs.vfsfs.VirtualFilesystem() + + // Get the path to the child dentry. This is only used to provide path + // information in failure case. + childPath, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD) + if err != nil { + return nil, err + } + + verityMu.RLock() + defer verityMu.RUnlock() + // Read the offset of the child from the extended attributes of the + // corresponding Merkle tree file. + // This is the offset of the root hash for child in its parent's Merkle + // tree file. + off, err := vfsObj.GetXattrAt(ctx, fs.creds, &vfs.PathOperation{ + Root: child.lowerMerkleVD, + Start: child.lowerMerkleVD, + }, &vfs.GetXattrOptions{ + Name: merkleOffsetInParentXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the file or the xattr does not + // exist, it indicates unexpected modifications to the file system. + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + } + if err != nil { + return nil, err + } + // The offset xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + offset, err := strconv.Atoi(off) + if err != nil { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + } + + // Open parent Merkle tree file to read and verify child's root hash. + parentMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerMerkleVD, + Start: parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + + // The parent Merkle tree file should have been created. If it's + // missing, it indicates an unexpected modification to the file system. + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + } + if err != nil { + return nil, err + } + + // dataSize is the size of raw data for the Merkle tree. For a file, + // dataSize is the size of the whole file. For a directory, dataSize is + // the size of all its children's root hashes. + dataSize, err := parentMerkleFD.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the file or the xattr does not + // exist, it indicates unexpected modifications to the file system. + if err == syserror.ENOENT || err == syserror.ENODATA { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + } + if err != nil { + return nil, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + parentSize, err := strconv.Atoi(dataSize) + if err != nil { + return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + } + + fdReader := vfs.FileReadWriteSeeker{ + FD: parentMerkleFD, + Ctx: ctx, + } + + // Since we are verifying against a directory Merkle tree, buf should + // contain the root hash of the children in the parent Merkle tree when + // Verify returns with success. + var buf bytes.Buffer + if _, err := merkletree.Verify(&buf, &fdReader, &fdReader, int64(parentSize), int64(offset), int64(merkletree.DigestSize()), parent.rootHash, true /* dataAndTreeInSameFile */); err != nil && err != io.EOF { + return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + } + + // Cache child root hash when it's verified the first time. + if len(child.rootHash) == 0 { + child.rootHash = buf.Bytes() + } + return child, nil +} + +// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { + if child, ok := parent.children[name]; ok { + // If enabling verification on files/directories is not allowed + // during runtime, all cached children are already verified. If + // runtime enable is allowed and the parent directory is + // enabled, we should verify the child root hash here because + // it may be cached before enabled. + if fs.allowRuntimeEnable && len(parent.rootHash) != 0 { + if _, err := fs.verifyChild(ctx, parent, child); err != nil { + return nil, err + } + } + return child, nil + } + child, err := fs.lookupAndVerifyLocked(ctx, parent, name) + if err != nil { + return nil, err + } + if parent.children == nil { + parent.children = make(map[string]*dentry) + } + parent.children[name] = child + // child's refcount is initially 0, so it may be dropped after traversal. + *ds = appendDentry(*ds, child) + return child, nil +} + +// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { + vfsObj := fs.vfsfs.VirtualFilesystem() + + childFilename := fspath.Parse(name) + childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: childFilename, + }, &vfs.GetDentryOptions{}) + + // We will handle ENOENT separately, as it may indicate unexpected + // modifications to the file system, and may cause a sentry panic. + if childErr != nil && childErr != syserror.ENOENT { + return nil, childErr + } + + // The dentry needs to be cleaned up if any error occurs. IncRef will be + // called if a verity child dentry is successfully created. + if childErr == nil { + defer childVD.DecRef(ctx) + } + + childMerkleFilename := merklePrefix + name + childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(childMerkleFilename), + }, &vfs.GetDentryOptions{}) + + // We will handle ENOENT separately, as it may indicate unexpected + // modifications to the file system, and may cause a sentry panic. + if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { + return nil, childMerkleErr + } + + // The dentry needs to be cleaned up if any error occurs. IncRef will be + // called if a verity child dentry is successfully created. + if childMerkleErr == nil { + defer childMerkleVD.DecRef(ctx) + } + + // Get the path to the parent dentry. This is only used to provide path + // information in failure case. + parentPath, err := vfsObj.PathnameWithDeleted(ctx, parent.fs.rootDentry.lowerVD, parent.lowerVD) + if err != nil { + return nil, err + } + + // TODO(b/166474175): Investigate all possible errors of childErr and + // childMerkleErr, and make sure we differentiate all errors that + // indicate unexpected modifications to the file system from the ones + // that are not harmful. + if childErr == syserror.ENOENT && childMerkleErr == nil { + // Failed to get child file/directory dentry. However the + // corresponding Merkle tree is found. This indicates an + // unexpected modification to the file system that + // removed/renamed the child. + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + } else if childErr == nil && childMerkleErr == syserror.ENOENT { + // If in allowRuntimeEnable mode, and the Merkle tree file is + // not created yet, we create an empty Merkle tree file, so that + // if the file is enabled through ioctl, we have the Merkle tree + // file open and ready to use. + // This may cause empty and unused Merkle tree files in + // allowRuntimeEnable mode, if they are never enabled. This + // does not affect verification, as we rely on cached root hash + // to decide whether to perform verification, not the existence + // of the Merkle tree file. Also, those Merkle tree files are + // always hidden and cannot be accessed by verity fs users. + if fs.allowRuntimeEnable { + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(childMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(childMerkleFilename), + }, &vfs.GetDentryOptions{}) + if err != nil { + return nil, err + } + } else { + // If runtime enable is not allowed. This indicates an + // unexpected modification to the file system that + // removed/renamed the Merkle tree file. + return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + } + } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { + // Both the child and the corresponding Merkle tree are missing. + // This could be an unexpected modification or due to incorrect + // parameter. + // TODO(b/167752508): Investigate possible ways to differentiate + // cases that both files are deleted from cases that they never + // exist in the file system. + return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) + } + + mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) + stat, err := vfsObj.StatAt(ctx, fs.creds, &vfs.PathOperation{ + Root: childVD, + Start: childVD, + }, &vfs.StatOptions{ + Mask: mask, + }) + if err != nil { + return nil, err + } + + child := fs.newDentry() + child.lowerVD = childVD + child.lowerMerkleVD = childMerkleVD + + // Increase the reference for both childVD and childMerkleVD as they are + // held by child. If this function fails and the child is destroyed, the + // references will be decreased in destroyLocked. + childVD.IncRef() + childMerkleVD.IncRef() + + parent.IncRef() + child.parent = parent + child.name = name + + // TODO(b/162788573): Verify child metadata. + child.mode = uint32(stat.Mode) + child.uid = stat.UID + child.gid = stat.GID + + // Verify child root hash. This should always be performed unless in + // allowRuntimeEnable mode and the parent directory hasn't been enabled + // yet. + if !(fs.allowRuntimeEnable && len(parent.rootHash) == 0) { + if _, err := fs.verifyChild(ctx, parent, child); err != nil { + child.destroyLocked(ctx) + return nil, err + } + } + + return child, nil } // walkParentDirLocked resolves all but the last path component of rp to an @@ -104,8 +451,39 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, // // Preconditions: fs.renameMu must be locked. !rp.Done(). func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { - // TODO(b/159261227): Implement walkParentDirLocked. - return nil, nil + for !rp.Final() { + d.dirMu.Lock() + next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + d.dirMu.Unlock() + if err != nil { + return nil, err + } + d = next + } + if !d.isDir() { + return nil, syserror.ENOTDIR + } + return d, nil +} + +// resolveLocked resolves rp to an existing file. +// +// Preconditions: fs.renameMu must be locked. +func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) { + d := rp.Start().Impl().(*dentry) + for !rp.Done() { + d.dirMu.Lock() + next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + d.dirMu.Unlock() + if err != nil { + return nil, err + } + d = next + } + if rp.MustBeDir() && !d.isDir() { + return nil, syserror.ENOTDIR + } + return d, nil } // AccessAt implements vfs.Filesystem.Impl.AccessAt. @@ -179,8 +557,181 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v // OpenAt implements vfs.FilesystemImpl.OpenAt. func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - //TODO(b/159261227): Implement OpenAt. - return nil, nil + // Verity fs is read-only. + if opts.Flags&(linux.O_WRONLY|linux.O_CREAT) != 0 { + return nil, syserror.EROFS + } + + var ds *[]*dentry + fs.renameMu.RLock() + defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) + + start := rp.Start().Impl().(*dentry) + if rp.Done() { + return start.openLocked(ctx, rp, &opts) + } + +afterTrailingSymlink: + parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds) + if err != nil { + return nil, err + } + + // Check for search permission in the parent directory. + if err := parent.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { + return nil, err + } + + // Open existing child or follow symlink. + parent.dirMu.Lock() + child, err := fs.stepLocked(ctx, rp, parent, false /*mayFollowSymlinks*/, &ds) + parent.dirMu.Unlock() + if err != nil { + return nil, err + } + if child.isSymlink() && rp.ShouldFollowSymlink() { + target, err := child.readlink(ctx) + if err != nil { + return nil, err + } + if err := rp.HandleSymlink(target); err != nil { + return nil, err + } + start = parent + goto afterTrailingSymlink + } + return child.openLocked(ctx, rp, &opts) +} + +// Preconditions: fs.renameMu must be locked. +func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.OpenOptions) (*vfs.FileDescription, error) { + // Users should not open the Merkle tree files. Those are for verity fs + // use only. + if strings.Contains(d.name, merklePrefix) { + return nil, syserror.EPERM + } + ats := vfs.AccessTypesForOpenFlags(opts) + if err := d.checkPermissions(rp.Credentials(), ats); err != nil { + return nil, err + } + + // Verity fs is read-only. + if ats&vfs.MayWrite != 0 { + return nil, syserror.EROFS + } + + // Get the path to the target file. This is only used to provide path + // information in failure case. + path, err := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.lowerVD) + if err != nil { + return nil, err + } + + // Open the file in the underlying file system. + lowerFD, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + }, opts) + + // The file should exist, as we succeeded in finding its dentry. If it's + // missing, it indicates an unexpected modification to the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + } + return nil, err + } + + // lowerFD needs to be cleaned up if any error occurs. IncRef will be + // called if a verity FD is successfully created. + defer lowerFD.DecRef(ctx) + + // Open the Merkle tree file corresponding to the current file/directory + // to be used later for verifying Read/Walk. + merkleReader, err := rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + + // The Merkle tree file should exist, as we succeeded in finding its + // dentry. If it's missing, it indicates an unexpected modification to + // the file system. + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + + // merkleReader needs to be cleaned up if any error occurs. IncRef will + // be called if a verity FD is successfully created. + defer merkleReader.DecRef(ctx) + + lowerFlags := lowerFD.StatusFlags() + lowerFDOpts := lowerFD.Options() + var merkleWriter *vfs.FileDescription + var parentMerkleWriter *vfs.FileDescription + + // Only open the Merkle tree files for write if in allowRuntimeEnable + // mode. + if d.fs.allowRuntimeEnable { + merkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + } + return nil, err + } + // merkleWriter is cleaned up if any error occurs. IncRef will + // be called if a verity FD is created successfully. + defer merkleWriter.DecRef(ctx) + + parentMerkleWriter, err = rp.VirtualFilesystem().OpenAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.parent.lowerMerkleVD, + Start: d.parent.lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_WRONLY | linux.O_APPEND, + }) + if err != nil { + if err == syserror.ENOENT { + parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) + return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + } + return nil, err + } + // parentMerkleWriter is cleaned up if any error occurs. IncRef + // will be called if a verity FD is created successfully. + defer parentMerkleWriter.DecRef(ctx) + } + + fd := &fileDescription{ + d: d, + lowerFD: lowerFD, + merkleReader: merkleReader, + merkleWriter: merkleWriter, + parentMerkleWriter: parentMerkleWriter, + isDir: d.isDir(), + } + + if err := fd.vfsfd.Init(fd, lowerFlags, rp.Mount(), &d.vfsd, &lowerFDOpts); err != nil { + return nil, err + } + lowerFD.IncRef() + merkleReader.IncRef() + if merkleWriter != nil { + merkleWriter.IncRef() + } + if parentMerkleWriter != nil { + parentMerkleWriter.IncRef() + } + return &fd.vfsfd, err } // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. @@ -256,7 +807,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error return syserror.EROFS } -// BoundEndpointAt implements FilesystemImpl.BoundEndpointAt. +// BoundEndpointAt implements vfs.FilesystemImpl.BoundEndpointAt. func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.BoundEndpointOptions) (transport.BoundEndpoint, error) { var ds *[]*dentry fs.renameMu.RLock() @@ -267,8 +818,8 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements vfs.FilesystemImpl.ListxattrAt. -func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements vfs.FilesystemImpl.ListXattrAt. +func (fs *filesystem) ListXattrAt(ctx context.Context, rp *vfs.ResolvingPath, size uint64) ([]string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) @@ -277,14 +828,14 @@ func (fs *filesystem) ListxattrAt(ctx context.Context, rp *vfs.ResolvingPath, si return nil, err } lowerVD := d.lowerVD - return fs.vfsfs.VirtualFilesystem().ListxattrAt(ctx, d.fs.creds, &vfs.PathOperation{ + return fs.vfsfs.VirtualFilesystem().ListXattrAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: lowerVD, Start: lowerVD, }, size) } -// GetxattrAt implements vfs.FilesystemImpl.GetxattrAt. -func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetxattrOptions) (string, error) { +// GetXattrAt implements vfs.FilesystemImpl.GetXattrAt. +func (fs *filesystem) GetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.GetXattrOptions) (string, error) { var ds *[]*dentry fs.renameMu.RLock() defer fs.renameMuRUnlockAndCheckDrop(ctx, &ds) @@ -293,20 +844,20 @@ func (fs *filesystem) GetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opt return "", err } lowerVD := d.lowerVD - return fs.vfsfs.VirtualFilesystem().GetxattrAt(ctx, d.fs.creds, &vfs.PathOperation{ + return fs.vfsfs.VirtualFilesystem().GetXattrAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: lowerVD, Start: lowerVD, }, &opts) } -// SetxattrAt implements vfs.FilesystemImpl.SetxattrAt. -func (fs *filesystem) SetxattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetxattrOptions) error { +// SetXattrAt implements vfs.FilesystemImpl.SetXattrAt. +func (fs *filesystem) SetXattrAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.SetXattrOptions) error { // Verity file system is read-only. return syserror.EROFS } -// RemovexattrAt implements vfs.FilesystemImpl.RemovexattrAt. -func (fs *filesystem) RemovexattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { +// RemoveXattrAt implements vfs.FilesystemImpl.RemoveXattrAt. +func (fs *filesystem) RemoveXattrAt(ctx context.Context, rp *vfs.ResolvingPath, name string) error { // Verity file system is read-only. return syserror.EROFS } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index cb29d33a5..9182df317 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -22,24 +22,56 @@ package verity import ( + "fmt" + "strconv" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // Name is the default filesystem name. const Name = "verity" -// testOnlyDebugging allows verity file system to return error instead of -// crashing the application when a malicious action is detected. This should -// only be set for tests. -var testOnlyDebugging bool +// merklePrefix is the prefix of the Merkle tree files. For example, the Merkle +// tree file for "/foo" is "/.merkle.verity.foo". +const merklePrefix = ".merkle.verity." + +// merkleoffsetInParentXattr is the extended attribute name specifying the +// offset of child root hash in its parent's Merkle tree. +const merkleOffsetInParentXattr = "user.merkle.offset" + +// merkleSizeXattr is the extended attribute name specifying the size of data +// hashed by the corresponding Merkle tree. For a file, it's the size of the +// whole file. For a directory, it's the size of all its children's root hashes. +const merkleSizeXattr = "user.merkle.size" + +// sizeOfStringInt32 is the size for a 32 bit integer stored as string in +// extended attributes. The maximum value of a 32 bit integer is 10 digits. +const sizeOfStringInt32 = 10 + +// noCrashOnVerificationFailure indicates whether the sandbox should panic +// whenever verification fails. If true, an error is returned instead of +// panicking. This should only be set for tests. +// TOOD(b/165661693): Decide whether to panic or return error based on this +// flag. +var noCrashOnVerificationFailure bool + +// verityMu synchronizes enabling verity files, protects files or directories +// from being enabled by different threads simultaneously. It also ensures that +// verity does not access files that are being enabled. +var verityMu sync.RWMutex // FilesystemType implements vfs.FilesystemType. type FilesystemType struct{} @@ -93,10 +125,10 @@ type InternalFilesystemOptions struct { // system wrapped by verity file system. LowerGetFSOptions vfs.GetFilesystemOptions - // TestOnlyDebugging allows verity file system to return error instead - // of crashing the application when a malicious action is detected. This - // should only be set for tests. - TestOnlyDebugging bool + // NoCrashOnVerificationFailure indicates whether the sandbox should + // panic whenever verification fails. If true, an error is returned + // instead of panicking. This should only be set for tests. + NoCrashOnVerificationFailure bool } // Name implements vfs.FilesystemType.Name. @@ -104,10 +136,120 @@ func (FilesystemType) Name() string { return Name } +// alertIntegrityViolation alerts a violation of integrity, which usually means +// unexpected modification to the file system is detected. In +// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. +func alertIntegrityViolation(err error, msg string) error { + if noCrashOnVerificationFailure { + return err + } + panic(msg) +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - //TODO(b/159261227): Implement GetFilesystem. - return nil, nil, nil + iopts, ok := opts.InternalData.(InternalFilesystemOptions) + if !ok { + ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") + return nil, nil, syserror.EINVAL + } + noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure + + // Mount the lower file system. The lower file system is wrapped inside + // verity, and should not be exposed or connected. + mopts := &vfs.MountOptions{ + GetFilesystemOptions: iopts.LowerGetFSOptions, + } + mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts) + if err != nil { + return nil, nil, err + } + + fs := &filesystem{ + creds: creds.Fork(), + lowerMount: mnt, + allowRuntimeEnable: iopts.AllowRuntimeEnable, + } + fs.vfsfs.Init(vfsObj, &fstype, fs) + + // Construct the root dentry. + d := fs.newDentry() + d.refs = 1 + lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root()) + lowerVD.IncRef() + d.lowerVD = lowerVD + + rootMerkleName := merklePrefix + iopts.RootMerkleFileName + + lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + Path: fspath.Parse(rootMerkleName), + }, &vfs.GetDentryOptions{}) + + // If runtime enable is allowed, the root merkle tree may be absent. We + // should create the tree file. + if err == syserror.ENOENT && fs.allowRuntimeEnable { + lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + Path: fspath.Parse(rootMerkleName), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, err + } + lowerMerkleFD.DecRef(ctx) + lowerMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + Path: fspath.Parse(rootMerkleName), + }, &vfs.GetDentryOptions{}) + if err != nil { + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, err + } + } else if err != nil { + // Failed to get dentry for the root Merkle file. This + // indicates an unexpected modification that removed/renamed + // the root Merkle file, or it's never generated. + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") + } + d.lowerMerkleVD = lowerMerkleVD + + // Get metadata from the underlying file system. + const statMask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID + stat, err := vfsObj.StatAt(ctx, creds, &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.StatOptions{ + Mask: statMask, + }) + if err != nil { + fs.vfsfs.DecRef(ctx) + d.DecRef(ctx) + return nil, nil, err + } + + // TODO(b/162788573): Verify Metadata. + d.mode = uint32(stat.Mode) + d.uid = stat.UID + d.gid = stat.GID + + d.rootHash = make([]byte, len(iopts.RootHash)) + copy(d.rootHash, iopts.RootHash) + d.vfsd.Init(d) + + fs.rootDentry = d + + return &fs.vfsfs, &d.vfsd, nil } // Release implements vfs.FilesystemImpl.Release. @@ -344,6 +486,194 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// generateMerkle generates a Merkle tree file for fd. If fd points to a file +// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The root +// hash of the generated Merkle tree and the data size is returned. +// If fd points to a regular file, the data is the content of the file. If fd +// points to a directory, the data is all root hahes of its children, written +// to the Merkle tree file. +func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) { + fdReader := vfs.FileReadWriteSeeker{ + FD: fd.lowerFD, + Ctx: ctx, + } + merkleReader := vfs.FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + merkleWriter := vfs.FileReadWriteSeeker{ + FD: fd.merkleWriter, + Ctx: ctx, + } + var rootHash []byte + var dataSize uint64 + + switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { + case linux.S_IFREG: + // For a regular file, generate a Merkle tree based on its + // content. + var err error + stat, err := fd.lowerFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return nil, 0, err + } + dataSize = stat.Size + + rootHash, err = merkletree.Generate(&fdReader, int64(dataSize), &merkleReader, &merkleWriter, false /* dataAndTreeInSameFile */) + if err != nil { + return nil, 0, err + } + case linux.S_IFDIR: + // For a directory, generate a Merkle tree based on the root + // hashes of its children that has already been written to the + // Merkle tree file. + merkleStat, err := fd.merkleReader.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return nil, 0, err + } + dataSize = merkleStat.Size + + rootHash, err = merkletree.Generate(&merkleReader, int64(dataSize), &merkleReader, &merkleWriter, true /* dataAndTreeInSameFile */) + if err != nil { + return nil, 0, err + } + default: + // TODO(b/167728857): Investigate whether and how we should + // enable other types of file. + return nil, 0, syserror.EINVAL + } + return rootHash, dataSize, nil +} + +// enableVerity enables verity features on fd by generating a Merkle tree file +// and stores its root hash in its parent directory's Merkle tree. +func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + if !fd.d.fs.allowRuntimeEnable { + return 0, syserror.EPERM + } + + // Lock to prevent other threads performing enable or access the file + // while it's being enabled. + verityMu.Lock() + defer verityMu.Unlock() + + if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || fd.parentMerkleWriter == nil { + return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + } + + rootHash, dataSize, err := fd.generateMerkle(ctx) + if err != nil { + return 0, err + } + + stat, err := fd.parentMerkleWriter.Stat(ctx, vfs.StatOptions{}) + if err != nil { + return 0, err + } + + // Write the root hash of fd to the parent directory's Merkle tree + // file, as it should be part of the parent Merkle tree data. + // parentMerkleWriter is open with O_APPEND, so it should write + // directly to the end of the file. + if _, err = fd.parentMerkleWriter.Write(ctx, usermem.BytesIOSequence(rootHash), vfs.WriteOptions{}); err != nil { + return 0, err + } + + // Record the offset of the root hash of fd in parent directory's + // Merkle tree file. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleOffsetInParentXattr, + Value: strconv.Itoa(int(stat.Size)), + }); err != nil { + return 0, err + } + + // Record the size of the data being hashed for fd. + if err := fd.merkleWriter.SetXattr(ctx, &vfs.SetXattrOptions{ + Name: merkleSizeXattr, + Value: strconv.Itoa(int(dataSize)), + }); err != nil { + return 0, err + } + fd.d.rootHash = append(fd.d.rootHash, rootHash...) + return 0, nil +} + +func (fd *fileDescription) getFlags(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + f := int32(0) + + // All enabled files should store a root hash. This flag is not settable + // via FS_IOC_SETFLAGS. + if len(fd.d.rootHash) != 0 { + f |= linux.FS_VERITY_FL + } + + t := kernel.TaskFromContext(ctx) + addr := args[2].Pointer() + _, err := primitive.CopyInt32Out(t, addr, f) + return 0, err +} + +// Ioctl implements vfs.FileDescriptionImpl.Ioctl. +func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + switch cmd := args[1].Uint(); cmd { + case linux.FS_IOC_ENABLE_VERITY: + return fd.enableVerity(ctx, uio, args) + case linux.FS_IOC_GETFLAGS: + return fd.getFlags(ctx, uio, args) + default: + return fd.lowerFD.Ioctl(ctx, uio, args) + } +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // No need to verify if the file is not enabled yet in + // allowRuntimeEnable mode. + if fd.d.fs.allowRuntimeEnable && len(fd.d.rootHash) == 0 { + return fd.lowerFD.PRead(ctx, dst, offset, opts) + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. + if err == syserror.ENODATA { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return 0, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + size, err := strconv.Atoi(dataSize) + if err != nil { + return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + dataReader := vfs.FileReadWriteSeeker{ + FD: fd.lowerFD, + Ctx: ctx, + } + + merkleReader := vfs.FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + n, err := merkletree.Verify(dst.Writer(ctx), &dataReader, &merkleReader, int64(size), offset, dst.NumBytes(), fd.d.rootHash, false /* dataAndTreeInSameFile */) + if err != nil { + return 0, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Verification failed: %v", err)) + } + return n, err +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 5416a310d..a43c549f1 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -74,6 +74,50 @@ go_template_instance( }, ) +go_template_instance( + name = "fd_table_refs", + out = "fd_table_refs.go", + package = "kernel", + prefix = "FDTable", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "FDTable", + }, +) + +go_template_instance( + name = "fs_context_refs", + out = "fs_context_refs.go", + package = "kernel", + prefix = "FSContext", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "FSContext", + }, +) + +go_template_instance( + name = "process_group_refs", + out = "process_group_refs.go", + package = "kernel", + prefix = "ProcessGroup", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "ProcessGroup", + }, +) + +go_template_instance( + name = "session_refs", + out = "session_refs.go", + package = "kernel", + prefix = "Session", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Session", + }, +) + proto_library( name = "uncaught_signal", srcs = ["uncaught_signal.proto"], @@ -88,9 +132,13 @@ go_library( "aio.go", "context.go", "fd_table.go", + "fd_table_refs.go", "fd_table_unsafe.go", "fs_context.go", + "fs_context_refs.go", "ipc_namespace.go", + "kcov.go", + "kcov_unsafe.go", "kernel.go", "kernel_opts.go", "kernel_state.go", @@ -99,6 +147,7 @@ go_library( "pending_signals_state.go", "posixtimer.go", "process_group_list.go", + "process_group_refs.go", "ptrace.go", "ptrace_amd64.go", "ptrace_arm64.go", @@ -106,6 +155,7 @@ go_library( "seccomp.go", "seqatomic_taskgoroutineschedinfo_unsafe.go", "session_list.go", + "session_refs.go", "sessions.go", "signal.go", "signal_handlers.go", @@ -147,6 +197,7 @@ go_library( "gvisor.dev/gvisor/pkg/sentry/device", "gvisor.dev/gvisor/pkg/tcpip", ], + marshal = True, visibility = ["//:sandbox"], deps = [ ":uncaught_signal_go_proto", @@ -157,10 +208,13 @@ go_library( "//pkg/bits", "//pkg/bpf", "//pkg/context", + "//pkg/coverage", "//pkg/cpuid", "//pkg/eventchannel", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", "//pkg/refs_vfs2", @@ -210,7 +264,6 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 2bc49483a..869e49ebc 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -57,6 +57,7 @@ go_library( "id_map_set.go", "user_namespace.go", ], + marshal = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/kernel/auth/context.go b/pkg/sentry/kernel/auth/context.go index ef5723127..c08d47787 100644 --- a/pkg/sentry/kernel/auth/context.go +++ b/pkg/sentry/kernel/auth/context.go @@ -34,3 +34,23 @@ func CredentialsFromContext(ctx context.Context) *Credentials { } return NewAnonymousCredentials() } + +// ContextWithCredentials returns a copy of ctx carrying creds. +func ContextWithCredentials(ctx context.Context, creds *Credentials) context.Context { + return &authContext{ctx, creds} +} + +type authContext struct { + context.Context + creds *Credentials +} + +// Value implements context.Context. +func (ac *authContext) Value(key interface{}) interface{} { + switch key { + case CtxCredentials: + return ac.creds + default: + return ac.Context.Value(key) + } +} diff --git a/pkg/sentry/kernel/auth/id.go b/pkg/sentry/kernel/auth/id.go index 0a58ba17c..4c32ee703 100644 --- a/pkg/sentry/kernel/auth/id.go +++ b/pkg/sentry/kernel/auth/id.go @@ -19,9 +19,13 @@ import ( ) // UID is a user ID in an unspecified user namespace. +// +// +marshal type UID uint32 // GID is a group ID in an unspecified user namespace. +// +// +marshal slice:GIDSlice type GID uint32 // In the root user namespace, user/group IDs have a 1-to-1 relationship with diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index ce53af69b..0ec7344cd 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -23,7 +23,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/limits" @@ -78,7 +77,8 @@ type descriptor struct { // // +stateify savable type FDTable struct { - refs.AtomicRefCount + FDTableRefs + k *Kernel // mu protects below. @@ -111,8 +111,11 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() f.init() // Initialize table. + f.used = 0 for fd, d := range m { - f.setAll(fd, d.file, d.fileVFS2, d.flags) + if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { + panic("VFS1 or VFS2 files set") + } // Note that we do _not_ need to acquire a extra table reference here. The // table reference will already be accounted for in the file, so we drop the @@ -127,7 +130,7 @@ func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { } // drop drops the table reference. -func (f *FDTable) drop(file *fs.File) { +func (f *FDTable) drop(ctx context.Context, file *fs.File) { // Release locks. file.Dirent.Inode.LockCtx.Posix.UnlockRegion(f, lock.LockRange{0, lock.LockEOF}) @@ -145,14 +148,13 @@ func (f *FDTable) drop(file *fs.File) { d.InotifyEvent(ev, 0) // Drop the table reference. - file.DecRef(context.Background()) + file.DecRef(ctx) } // dropVFS2 drops the table reference. -func (f *FDTable) dropVFS2(file *vfs.FileDescription) { +func (f *FDTable) dropVFS2(ctx context.Context, file *vfs.FileDescription) { // Release any POSIX lock possibly held by the FDTable. Range {0, 0} means the // entire file. - ctx := context.Background() err := file.UnlockPOSIX(ctx, f, 0, 0, linux.SEEK_SET) if err != nil && err != syserror.ENOLCK { panic(fmt.Sprintf("UnlockPOSIX failed: %v", err)) @@ -176,22 +178,15 @@ func (k *Kernel) NewFDTable() *FDTable { return f } -// destroy removes all of the file descriptors from the map. -func (f *FDTable) destroy(ctx context.Context) { - f.RemoveIf(ctx, func(*fs.File, *vfs.FileDescription, FDFlags) bool { - return true - }) -} - -// DecRef implements RefCounter.DecRef with destructor f.destroy. +// DecRef implements RefCounter.DecRef. +// +// If f reaches zero references, all of its file descriptors are removed. func (f *FDTable) DecRef(ctx context.Context) { - f.DecRefWithDestructor(ctx, f.destroy) -} - -// Size returns the number of file descriptor slots currently allocated. -func (f *FDTable) Size() int { - size := atomic.LoadInt32(&f.used) - return int(size) + f.FDTableRefs.DecRef(func() { + f.RemoveIf(ctx, func(*fs.File, *vfs.FileDescription, FDFlags) bool { + return true + }) + }) } // forEach iterates over all non-nil files in sorted order. @@ -280,7 +275,6 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags } f.mu.Lock() - defer f.mu.Unlock() // From f.next to find available fd. if fd < f.next { @@ -290,15 +284,25 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.get(i); d == nil { - f.set(i, files[len(fds)], flags) // Set the descriptor. - fds = append(fds, i) // Record the file descriptor. + // Set the descriptor. + f.set(ctx, i, files[len(fds)], flags) + fds = append(fds, i) // Record the file descriptor. } } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { - f.set(i, nil, FDFlags{}) // Zap entry. + f.set(ctx, i, nil, FDFlags{}) + } + f.mu.Unlock() + + // Drop the reference taken by the call to f.set() that + // originally installed the file. Don't call f.drop() + // (generating inotify events, etc.) since the file should + // appear to have never been inserted into f. + for _, file := range files[:len(fds)] { + file.DecRef(ctx) } return nil, syscall.EMFILE } @@ -308,6 +312,7 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags f.next = fds[len(fds)-1] + 1 } + f.mu.Unlock() return fds, nil } @@ -335,7 +340,6 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes } f.mu.Lock() - defer f.mu.Unlock() // From f.next to find available fd. if fd < f.next { @@ -345,15 +349,25 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes // Install all entries. for i := fd; i < end && len(fds) < len(files); i++ { if d, _, _ := f.getVFS2(i); d == nil { - f.setVFS2(i, files[len(fds)], flags) // Set the descriptor. - fds = append(fds, i) // Record the file descriptor. + // Set the descriptor. + f.setVFS2(ctx, i, files[len(fds)], flags) + fds = append(fds, i) // Record the file descriptor. } } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { - f.setVFS2(i, nil, FDFlags{}) // Zap entry. + f.setVFS2(ctx, i, nil, FDFlags{}) + } + f.mu.Unlock() + + // Drop the reference taken by the call to f.setVFS2() that + // originally installed the file. Don't call f.dropVFS2() + // (generating inotify events, etc.) since the file should + // appear to have never been inserted into f. + for _, file := range files[:len(fds)] { + file.DecRef(ctx) } return nil, syscall.EMFILE } @@ -363,6 +377,7 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes f.next = fds[len(fds)-1] + 1 } + f.mu.Unlock() return fds, nil } @@ -398,7 +413,7 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc } for fd < end { if d, _, _ := f.getVFS2(fd); d == nil { - f.setVFS2(fd, file, flags) + f.setVFS2(ctx, fd, file, flags) if fd == f.next { // Update next search start position. f.next = fd + 1 @@ -414,40 +429,55 @@ func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDesc // reference for that FD, the ref count for that existing reference is // decremented. func (f *FDTable) NewFDAt(ctx context.Context, fd int32, file *fs.File, flags FDFlags) error { - return f.newFDAt(ctx, fd, file, nil, flags) + df, _, err := f.newFDAt(ctx, fd, file, nil, flags) + if err != nil { + return err + } + if df != nil { + f.drop(ctx, df) + } + return nil } // NewFDAtVFS2 sets the file reference for the given FD. If there is an active // reference for that FD, the ref count for that existing reference is // decremented. func (f *FDTable) NewFDAtVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) error { - return f.newFDAt(ctx, fd, nil, file, flags) + _, dfVFS2, err := f.newFDAt(ctx, fd, nil, file, flags) + if err != nil { + return err + } + if dfVFS2 != nil { + f.dropVFS2(ctx, dfVFS2) + } + return nil } -func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) error { +func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription, error) { if fd < 0 { // Don't accept negative FDs. - return syscall.EBADF + return nil, nil, syscall.EBADF } // Check the limit for the provided file. if limitSet := limits.FromContext(ctx); limitSet != nil { if lim := limitSet.Get(limits.NumberOfFiles); lim.Cur != limits.Infinity && uint64(fd) >= lim.Cur { - return syscall.EMFILE + return nil, nil, syscall.EMFILE } } // Install the entry. f.mu.Lock() defer f.mu.Unlock() - f.setAll(fd, file, fileVFS2, flags) - return nil + + df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags) + return df, dfVFS2, nil } // SetFlags sets the flags for the given file descriptor. // // True is returned iff flags were changed. -func (f *FDTable) SetFlags(fd int32, flags FDFlags) error { +func (f *FDTable) SetFlags(ctx context.Context, fd int32, flags FDFlags) error { if fd < 0 { // Don't accept negative FDs. return syscall.EBADF @@ -463,14 +493,14 @@ func (f *FDTable) SetFlags(fd int32, flags FDFlags) error { } // Update the flags. - f.set(fd, file, flags) + f.set(ctx, fd, file, flags) return nil } // SetFlagsVFS2 sets the flags for the given file descriptor. // // True is returned iff flags were changed. -func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error { +func (f *FDTable) SetFlagsVFS2(ctx context.Context, fd int32, flags FDFlags) error { if fd < 0 { // Don't accept negative FDs. return syscall.EBADF @@ -486,7 +516,7 @@ func (f *FDTable) SetFlagsVFS2(fd int32, flags FDFlags) error { } // Update the flags. - f.setVFS2(fd, file, flags) + f.setVFS2(ctx, fd, file, flags) return nil } @@ -552,30 +582,6 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 { return fds } -// GetRefs returns a stable slice of references to all files and bumps the -// reference count on each. The caller must use DecRef on each reference when -// they're done using the slice. -func (f *FDTable) GetRefs(ctx context.Context) []*fs.File { - files := make([]*fs.File, 0, f.Size()) - f.forEach(ctx, func(_ int32, file *fs.File, _ *vfs.FileDescription, _ FDFlags) { - file.IncRef() // Acquire a reference for caller. - files = append(files, file) - }) - return files -} - -// GetRefsVFS2 returns a stable slice of references to all files and bumps the -// reference count on each. The caller must use DecRef on each reference when -// they're done using the slice. -func (f *FDTable) GetRefsVFS2(ctx context.Context) []*vfs.FileDescription { - files := make([]*vfs.FileDescription, 0, f.Size()) - f.forEach(ctx, func(_ int32, _ *fs.File, file *vfs.FileDescription, _ FDFlags) { - file.IncRef() // Acquire a reference for caller. - files = append(files, file) - }) - return files -} - // Fork returns an independent FDTable. func (f *FDTable) Fork(ctx context.Context) *FDTable { clone := f.k.NewFDTable() @@ -583,11 +589,8 @@ func (f *FDTable) Fork(ctx context.Context) *FDTable { f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { // The set function here will acquire an appropriate table // reference for the clone. We don't need anything else. - switch { - case file != nil: - clone.set(fd, file, flags) - case fileVFS2 != nil: - clone.setVFS2(fd, fileVFS2, flags) + if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil { + panic("VFS1 or VFS2 files set") } }) return clone @@ -596,13 +599,12 @@ func (f *FDTable) Fork(ctx context.Context) *FDTable { // Remove removes an FD from and returns a non-file iff successful. // // N.B. Callers are required to use DecRef when they are done. -func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { +func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDescription) { if fd < 0 { return nil, nil } f.mu.Lock() - defer f.mu.Unlock() // Update current available position. if fd < f.next { @@ -618,24 +620,51 @@ func (f *FDTable) Remove(fd int32) (*fs.File, *vfs.FileDescription) { case orig2 != nil: orig2.IncRef() } + if orig != nil || orig2 != nil { - f.setAll(fd, nil, nil, FDFlags{}) // Zap entry. + orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry. } + f.mu.Unlock() + + if orig != nil { + f.drop(ctx, orig) + } + if orig2 != nil { + f.dropVFS2(ctx, orig2) + } + return orig, orig2 } // RemoveIf removes all FDs where cond is true. func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDescription, FDFlags) bool) { - f.mu.Lock() - defer f.mu.Unlock() + // TODO(gvisor.dev/issue/1624): Remove fs.File slice. + var files []*fs.File + var filesVFS2 []*vfs.FileDescription + f.mu.Lock() f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { if cond(file, fileVFS2, flags) { - f.set(fd, nil, FDFlags{}) // Clear from table. + df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table. + if df != nil { + files = append(files, df) + } + if dfVFS2 != nil { + filesVFS2 = append(filesVFS2, dfVFS2) + } // Update current available position. if fd < f.next { f.next = fd } } }) + f.mu.Unlock() + + for _, file := range files { + f.drop(ctx, file) + } + + for _, file := range filesVFS2 { + f.dropVFS2(ctx, file) + } } diff --git a/pkg/sentry/kernel/fd_table_test.go b/pkg/sentry/kernel/fd_table_test.go index e3f30ba2a..bf5460083 100644 --- a/pkg/sentry/kernel/fd_table_test.go +++ b/pkg/sentry/kernel/fd_table_test.go @@ -72,7 +72,7 @@ func TestFDTableMany(t *testing.T) { } i := int32(2) - fdTable.Remove(i) + fdTable.Remove(ctx, i) if fds, err := fdTable.NewFDs(ctx, 0, []*fs.File{file}, FDFlags{}); err != nil || fds[0] != i { t.Fatalf("Allocated %v FDs but wanted to allocate %v: %v", i, maxFD, err) } @@ -93,7 +93,7 @@ func TestFDTableOverLimit(t *testing.T) { t.Fatalf("fdTable.NewFDs(maxFD-3, {f,f,f}): got %v, wanted nil", err) } else { for _, fd := range fds { - fdTable.Remove(fd) + fdTable.Remove(ctx, fd) } } @@ -150,13 +150,13 @@ func TestFDTable(t *testing.T) { t.Fatalf("fdTable.Get(2): got a %v, wanted nil", ref) } - ref, _ := fdTable.Remove(1) + ref, _ := fdTable.Remove(ctx, 1) if ref == nil { t.Fatalf("fdTable.Remove(1) for an existing FD: failed, want success") } ref.DecRef(ctx) - if ref, _ := fdTable.Remove(1); ref != nil { + if ref, _ := fdTable.Remove(ctx, 1); ref != nil { t.Fatalf("r.Remove(1) for a removed FD: got success, want failure") } }) diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index 7fd97dc53..da79e6627 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "unsafe" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -31,6 +32,8 @@ type descriptorTable struct { } // init initializes the table. +// +// TODO(gvisor.dev/1486): Enable leak check for FDTable. func (f *FDTable) init() { var slice []unsafe.Pointer // Empty slice. atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) @@ -76,33 +79,37 @@ func (f *FDTable) getAll(fd int32) (*fs.File, *vfs.FileDescription, FDFlags, boo return d.file, d.fileVFS2, d.flags, true } -// set sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// CurrentMaxFDs returns the number of file descriptors that may be stored in f +// without reallocation. +func (f *FDTable) CurrentMaxFDs() int { + slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) + return len(slice) +} + +// set sets an entry for VFS1, refer to setAll(). // // Precondition: mu must be held. -func (f *FDTable) set(fd int32, file *fs.File, flags FDFlags) { - f.setAll(fd, file, nil, flags) +func (f *FDTable) set(ctx context.Context, fd int32, file *fs.File, flags FDFlags) *fs.File { + dropFile, _ := f.setAll(ctx, fd, file, nil, flags) + return dropFile } -// setVFS2 sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// setVFS2 sets an entry for VFS2, refer to setAll(). // // Precondition: mu must be held. -func (f *FDTable) setVFS2(fd int32, file *vfs.FileDescription, flags FDFlags) { - f.setAll(fd, nil, file, flags) +func (f *FDTable) setVFS2(ctx context.Context, fd int32, file *vfs.FileDescription, flags FDFlags) *vfs.FileDescription { + _, dropFile := f.setAll(ctx, fd, nil, file, flags) + return dropFile } -// setAll sets an entry. -// -// This handles accounting changes, as well as acquiring and releasing the -// reference needed by the table iff the file is different. +// setAll sets the file description referred to by fd to file/fileVFS2. If +// file/fileVFS2 are non-nil, it takes a reference on them. If setAll replaces +// an existing file description, it returns it with the FDTable's reference +// transferred to the caller, which must call f.drop/dropVFS2() on the returned +// file after unlocking f.mu. // // Precondition: mu must be held. -func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { +func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) (*fs.File, *vfs.FileDescription) { if file != nil && fileVFS2 != nil { panic("VFS1 and VFS2 files set") } @@ -145,25 +152,25 @@ func (f *FDTable) setAll(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, } } - // Drop the table reference. + // Adjust used. + switch { + case orig == nil && desc != nil: + atomic.AddInt32(&f.used, 1) + case orig != nil && desc == nil: + atomic.AddInt32(&f.used, -1) + } + if orig != nil { switch { case orig.file != nil: if desc == nil || desc.file != orig.file { - f.drop(orig.file) + return orig.file, nil } case orig.fileVFS2 != nil: if desc == nil || desc.fileVFS2 != orig.fileVFS2 { - f.dropVFS2(orig.fileVFS2) + return nil, orig.fileVFS2 } } } - - // Adjust used. - switch { - case orig == nil && desc != nil: - atomic.AddInt32(&f.used, 1) - case orig != nil && desc == nil: - atomic.AddInt32(&f.used, -1) - } + return nil, nil } diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index 8f2d36d5a..d46d1e1c1 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -18,7 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -30,7 +29,7 @@ import ( // // +stateify savable type FSContext struct { - refs.AtomicRefCount + FSContextRefs // mu protects below. mu sync.Mutex `state:"nosave"` @@ -64,7 +63,7 @@ func newFSContext(root, cwd *fs.Dirent, umask uint) *FSContext { cwd: cwd, umask: umask, } - f.EnableLeakCheck("kernel.FSContext") + f.EnableLeakCheck() return &f } @@ -77,54 +76,56 @@ func NewFSContextVFS2(root, cwd vfs.VirtualDentry, umask uint) *FSContext { cwdVFS2: cwd, umask: umask, } - f.EnableLeakCheck("kernel.FSContext") + f.EnableLeakCheck() return &f } -// destroy is the destructor for an FSContext. +// DecRef implements RefCounter.DecRef. // -// This will call DecRef on both root and cwd Dirents. If either call to -// DecRef returns an error, then it will be propagated. If both calls to -// DecRef return an error, then the one from root.DecRef will be propagated. +// When f reaches zero references, DecRef will be called on both root and cwd +// Dirents. // // Note that there may still be calls to WorkingDirectory() or RootDirectory() // (that return nil). This is because valid references may still be held via // proc files or other mechanisms. -func (f *FSContext) destroy(ctx context.Context) { - // Hold f.mu so that we don't race with RootDirectory() and - // WorkingDirectory(). - f.mu.Lock() - defer f.mu.Unlock() - - if VFS2Enabled { - f.rootVFS2.DecRef(ctx) - f.rootVFS2 = vfs.VirtualDentry{} - f.cwdVFS2.DecRef(ctx) - f.cwdVFS2 = vfs.VirtualDentry{} - } else { - f.root.DecRef(ctx) - f.root = nil - f.cwd.DecRef(ctx) - f.cwd = nil - } -} - -// DecRef implements RefCounter.DecRef with destructor f.destroy. func (f *FSContext) DecRef(ctx context.Context) { - f.DecRefWithDestructor(ctx, f.destroy) + f.FSContextRefs.DecRef(func() { + // Hold f.mu so that we don't race with RootDirectory() and + // WorkingDirectory(). + f.mu.Lock() + defer f.mu.Unlock() + + if VFS2Enabled { + f.rootVFS2.DecRef(ctx) + f.rootVFS2 = vfs.VirtualDentry{} + f.cwdVFS2.DecRef(ctx) + f.cwdVFS2 = vfs.VirtualDentry{} + } else { + f.root.DecRef(ctx) + f.root = nil + f.cwd.DecRef(ctx) + f.cwd = nil + } + }) } // Fork forks this FSContext. // -// This is not a valid call after destroy. +// This is not a valid call after f is destroyed. func (f *FSContext) Fork() *FSContext { f.mu.Lock() defer f.mu.Unlock() if VFS2Enabled { + if !f.cwdVFS2.Ok() { + panic("FSContext.Fork() called after destroy") + } f.cwdVFS2.IncRef() f.rootVFS2.IncRef() } else { + if f.cwd == nil { + panic("FSContext.Fork() called after destroy") + } f.cwd.IncRef() f.root.IncRef() } @@ -140,8 +141,8 @@ func (f *FSContext) Fork() *FSContext { // WorkingDirectory returns the current working directory. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) WorkingDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() @@ -152,8 +153,8 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent { // WorkingDirectoryVFS2 returns the current working directory. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() @@ -165,7 +166,7 @@ func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { // SetWorkingDirectory sets the current working directory. // This will take an extra reference on the Dirent. // -// This is not a valid call after destroy. +// This is not a valid call after f is destroyed. func (f *FSContext) SetWorkingDirectory(ctx context.Context, d *fs.Dirent) { if d == nil { panic("FSContext.SetWorkingDirectory called with nil dirent") @@ -187,11 +188,15 @@ func (f *FSContext) SetWorkingDirectory(ctx context.Context, d *fs.Dirent) { // SetWorkingDirectoryVFS2 sets the current working directory. // This will take an extra reference on the VirtualDentry. // -// This is not a valid call after destroy. +// This is not a valid call after f is destroyed. func (f *FSContext) SetWorkingDirectoryVFS2(ctx context.Context, d vfs.VirtualDentry) { f.mu.Lock() defer f.mu.Unlock() + if !f.cwdVFS2.Ok() { + panic(fmt.Sprintf("FSContext.SetWorkingDirectoryVFS2(%v)) called after destroy", d)) + } + old := f.cwdVFS2 f.cwdVFS2 = d d.IncRef() @@ -200,8 +205,8 @@ func (f *FSContext) SetWorkingDirectoryVFS2(ctx context.Context, d vfs.VirtualDe // RootDirectory returns the current filesystem root. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) RootDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() @@ -213,8 +218,8 @@ func (f *FSContext) RootDirectory() *fs.Dirent { // RootDirectoryVFS2 returns the current filesystem root. // -// This will return nil if called after destroy(), otherwise it will return a -// Dirent with a reference taken. +// This will return nil if called after f is destroyed, otherwise it will return +// a Dirent with a reference taken. func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() @@ -226,7 +231,7 @@ func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { // SetRootDirectory sets the root directory. // This will take an extra reference on the Dirent. // -// This is not a valid call after free. +// This is not a valid call after f is destroyed. func (f *FSContext) SetRootDirectory(ctx context.Context, d *fs.Dirent) { if d == nil { panic("FSContext.SetRootDirectory called with nil dirent") @@ -247,7 +252,7 @@ func (f *FSContext) SetRootDirectory(ctx context.Context, d *fs.Dirent) { // SetRootDirectoryVFS2 sets the root directory. It takes a reference on vd. // -// This is not a valid call after free. +// This is not a valid call after f is destroyed. func (f *FSContext) SetRootDirectoryVFS2(ctx context.Context, vd vfs.VirtualDentry) { if !vd.Ok() { panic("FSContext.SetRootDirectoryVFS2 called with zero-value VirtualDentry") diff --git a/pkg/sentry/kernel/kcov.go b/pkg/sentry/kernel/kcov.go new file mode 100644 index 000000000..aad63aa99 --- /dev/null +++ b/pkg/sentry/kernel/kcov.go @@ -0,0 +1,321 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "fmt" + "io" + "sync" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/coverage" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/memmap" + "gvisor.dev/gvisor/pkg/sentry/mm" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" + "gvisor.dev/gvisor/pkg/sentry/usage" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +// kcovAreaSizeMax is the maximum number of uint64 entries allowed in the kcov +// area. On Linux, the maximum is INT_MAX / 8. +const kcovAreaSizeMax = 10 * 1024 * 1024 + +// Kcov provides kernel coverage data to userspace through a memory-mapped +// region, as kcov does in Linux. +// +// To give the illusion that the data is always up to date, we update the shared +// memory every time before we return to userspace. +type Kcov struct { + // mfp provides application memory. It is immutable after creation. + mfp pgalloc.MemoryFileProvider + + // mu protects all of the fields below. + mu sync.RWMutex + + // mode is the current kcov mode. + mode uint8 + + // size is the size of the mapping through which the kernel conveys coverage + // information to userspace. + size uint64 + + // owningTask is the task that currently owns coverage data on the system. The + // interface for kcov essentially requires that coverage is only going to a + // single task. Note that kcov should only generate coverage data for the + // owning task, but we currently generate global coverage. + owningTask *Task + + // count is a locally cached version of the first uint64 in the kcov data, + // which is the number of subsequent entries representing PCs. + // + // It is used with kcovInode.countBlock(), to copy in/out the first element of + // the actual data in an efficient manner, avoid boilerplate, and prevent + // accidental garbage escapes by the temporary counts. + count uint64 + + mappable *mm.SpecialMappable +} + +// NewKcov creates and returns a Kcov instance. +func (k *Kernel) NewKcov() *Kcov { + return &Kcov{ + mfp: k, + } +} + +var coveragePool = sync.Pool{ + New: func() interface{} { + return make([]byte, 0) + }, +} + +// TaskWork implements TaskWorker.TaskWork. +func (kcov *Kcov) TaskWork(t *Task) { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + rw := &kcovReadWriter{ + mf: kcov.mfp.MemoryFile(), + fr: kcov.mappable.FileRange(), + } + + // Read in the PC count. + if _, err := safemem.ReadFullToBlocks(rw, kcov.countBlock()); err != nil { + panic(fmt.Sprintf("Internal error reading count from kcov area: %v", err)) + } + + rw.off = 8 * (1 + kcov.count) + n := coverage.ConsumeCoverageData(&kcovIOWriter{rw}) + + // Update the pc count, based on the number of entries written. Note that if + // we reached the end of the kcov area, we may not have written everything in + // output. + kcov.count += uint64(n / 8) + rw.off = 0 + if _, err := safemem.WriteFullFromBlocks(rw, kcov.countBlock()); err != nil { + panic(fmt.Sprintf("Internal error writing count to kcov area: %v", err)) + } + + // Re-register for future work. + t.RegisterWork(kcov) +} + +// InitTrace performs the KCOV_INIT_TRACE ioctl. +func (kcov *Kcov) InitTrace(size uint64) error { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + if kcov.mode != linux.KCOV_MODE_DISABLED { + return syserror.EBUSY + } + + // To simplify all the logic around mapping, we require that the length of the + // shared region is a multiple of the system page size. + if (8*size)&(usermem.PageSize-1) != 0 { + return syserror.EINVAL + } + + // We need space for at least two uint64s to hold current position and a + // single PC. + if size < 2 || size > kcovAreaSizeMax { + return syserror.EINVAL + } + + kcov.size = size + kcov.mode = linux.KCOV_MODE_INIT + return nil +} + +// EnableTrace performs the KCOV_ENABLE_TRACE ioctl. +func (kcov *Kcov) EnableTrace(ctx context.Context, traceMode uint8) error { + t := TaskFromContext(ctx) + if t == nil { + panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine") + } + + kcov.mu.Lock() + defer kcov.mu.Unlock() + + // KCOV_ENABLE must be preceded by KCOV_INIT_TRACE and an mmap call. + if kcov.mode != linux.KCOV_MODE_INIT || kcov.mappable == nil { + return syserror.EINVAL + } + + switch traceMode { + case linux.KCOV_TRACE_PC: + kcov.mode = traceMode + case linux.KCOV_TRACE_CMP: + // We do not support KCOV_MODE_TRACE_CMP. + return syserror.ENOTSUP + default: + return syserror.EINVAL + } + + if kcov.owningTask != nil && kcov.owningTask != t { + return syserror.EBUSY + } + + kcov.owningTask = t + t.RegisterWork(kcov) + + // Clear existing coverage data; the task expects to read only coverage data + // from the time it is activated. + coverage.ClearCoverageData() + return nil +} + +// DisableTrace performs the KCOV_DISABLE_TRACE ioctl. +func (kcov *Kcov) DisableTrace(ctx context.Context) error { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + t := TaskFromContext(ctx) + if t == nil { + panic("kcovInode.EnableTrace() cannot be used outside of a task goroutine") + } + + if t != kcov.owningTask { + return syserror.EINVAL + } + kcov.owningTask = nil + kcov.mode = linux.KCOV_MODE_INIT + kcov.resetLocked() + return nil +} + +// Reset is called when the owning task exits. +func (kcov *Kcov) Reset() { + kcov.mu.Lock() + kcov.resetLocked() + kcov.mu.Unlock() +} + +// The kcov instance is reset when the owning task exits or when tracing is +// disabled. +func (kcov *Kcov) resetLocked() { + kcov.owningTask = nil + if kcov.mappable != nil { + kcov.mappable = nil + } +} + +// ConfigureMMap is called by the vfs.FileDescription for this kcov instance to +// implement vfs.FileDescription.ConfigureMMap. +func (kcov *Kcov) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + kcov.mu.Lock() + defer kcov.mu.Unlock() + + if kcov.mode != linux.KCOV_MODE_INIT { + return syserror.EINVAL + } + + if kcov.mappable == nil { + // Set up the kcov area. + fr, err := kcov.mfp.MemoryFile().Allocate(kcov.size*8, usage.Anonymous) + if err != nil { + return err + } + + // Get the thread id for the mmap name. + t := TaskFromContext(ctx) + if t == nil { + panic("ThreadFromContext returned nil") + } + // For convenience, a special mappable is used here. Note that these mappings + // will look different under /proc/[pid]/maps than they do on Linux. + kcov.mappable = mm.NewSpecialMappable(fmt.Sprintf("[kcov:%d]", t.ThreadID()), kcov.mfp, fr) + } + opts.Mappable = kcov.mappable + opts.MappingIdentity = kcov.mappable + return nil +} + +// kcovReadWriter implements safemem.Reader and safemem.Writer. +type kcovReadWriter struct { + off uint64 + mf *pgalloc.MemoryFile + fr memmap.FileRange +} + +// ReadToBlocks implements safemem.Reader.ReadToBlocks. +func (rw *kcovReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + if dsts.IsEmpty() { + return 0, nil + } + + // Limit the read to the kcov range and check for overflow. + if rw.fr.Length() <= rw.off { + return 0, io.EOF + } + start := rw.fr.Start + rw.off + end := rw.fr.Start + rw.fr.Length() + if rend := start + dsts.NumBytes(); rend < end { + end = rend + } + + // Get internal mappings. + bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Read) + if err != nil { + return 0, err + } + + // Copy from internal mappings. + n, err := safemem.CopySeq(dsts, bs) + rw.off += n + return n, err +} + +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks. +func (rw *kcovReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + if srcs.IsEmpty() { + return 0, nil + } + + // Limit the write to the kcov area and check for overflow. + if rw.fr.Length() <= rw.off { + return 0, io.EOF + } + start := rw.fr.Start + rw.off + end := rw.fr.Start + rw.fr.Length() + if wend := start + srcs.NumBytes(); wend < end { + end = wend + } + + // Get internal mapping. + bs, err := rw.mf.MapInternal(memmap.FileRange{start, end}, usermem.Write) + if err != nil { + return 0, err + } + + // Copy to internal mapping. + n, err := safemem.CopySeq(bs, srcs) + rw.off += n + return n, err +} + +// kcovIOWriter implements io.Writer as a basic wrapper over kcovReadWriter. +type kcovIOWriter struct { + rw *kcovReadWriter +} + +// Write implements io.Writer.Write. +func (w *kcovIOWriter) Write(p []byte) (int, error) { + bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p)) + n, err := safemem.WriteFullFromBlocks(w.rw, bs) + return int(n), err +} diff --git a/pkg/sentry/kernel/kcov_unsafe.go b/pkg/sentry/kernel/kcov_unsafe.go new file mode 100644 index 000000000..6f64022eb --- /dev/null +++ b/pkg/sentry/kernel/kcov_unsafe.go @@ -0,0 +1,28 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "unsafe" + + "gvisor.dev/gvisor/pkg/safemem" +) + +// countBlock provides a safemem.BlockSeq for k.count. +// +// Like k.count, the block returned is protected by k.mu. +func (k *Kcov) countBlock() safemem.BlockSeq { + return safemem.BlockSeqOf(safemem.BlockFromSafePointer(unsafe.Pointer(&k.count), int(unsafe.Sizeof(k.count)))) +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 1028d13c6..22f9bb006 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -248,7 +248,7 @@ type Kernel struct { // SpecialOpts contains special kernel options. SpecialOpts - // VFS keeps the filesystem state used across the kernel. + // vfs keeps the filesystem state used across the kernel. vfs vfs.VirtualFilesystem // hostMount is the Mount used for file descriptors that were imported @@ -888,17 +888,18 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, opener fsbridge.Lookup fsContext *FSContext mntns *fs.MountNamespace + mntnsVFS2 *vfs.MountNamespace ) if VFS2Enabled { - mntnsVFS2 := args.MountNamespaceVFS2 + mntnsVFS2 = args.MountNamespaceVFS2 if mntnsVFS2 == nil { // MountNamespaceVFS2 adds a reference to the namespace, which is // transferred to the new process. mntnsVFS2 = k.globalInit.Leader().MountNamespaceVFS2() } // Get the root directory from the MountNamespace. - root := args.MountNamespaceVFS2.Root() + root := mntnsVFS2.Root() // The call to newFSContext below will take a reference on root, so we // don't need to hold this one. defer root.DecRef(ctx) @@ -1008,7 +1009,7 @@ func (k *Kernel) CreateProcess(args CreateProcessArgs) (*ThreadGroup, ThreadID, UTSNamespace: args.UTSNamespace, IPCNamespace: args.IPCNamespace, AbstractSocketNamespace: args.AbstractSocketNamespace, - MountNamespaceVFS2: args.MountNamespaceVFS2, + MountNamespaceVFS2: mntnsVFS2, ContainerID: args.ContainerID, } t, err := k.tasks.NewTask(config) @@ -1067,8 +1068,9 @@ func (k *Kernel) Start() error { // pauseTimeLocked pauses all Timers and Timekeeper updates. // -// Preconditions: Any task goroutines running in k must be stopped. k.extMu -// must be locked. +// Preconditions: +// * Any task goroutines running in k must be stopped. +// * k.extMu must be locked. func (k *Kernel) pauseTimeLocked(ctx context.Context) { // k.cpuClockTicker may be nil since Kernel.SaveTo() may be called before // Kernel.Start(). @@ -1111,8 +1113,9 @@ func (k *Kernel) pauseTimeLocked(ctx context.Context) { // pauseTimeLocked has not been previously called, resumeTimeLocked has no // effect. // -// Preconditions: Any task goroutines running in k must be stopped. k.extMu -// must be locked. +// Preconditions: +// * Any task goroutines running in k must be stopped. +// * k.extMu must be locked. func (k *Kernel) resumeTimeLocked(ctx context.Context) { if k.cpuClockTicker != nil { k.cpuClockTicker.Resume() diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 297e8f28f..c410c96aa 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -200,17 +200,17 @@ type readOps struct { // // Precondition: this pipe must have readers. func (p *Pipe) read(ctx context.Context, ops readOps) (int64, error) { - // Don't block for a zero-length read even if the pipe is empty. - if ops.left() == 0 { - return 0, nil - } - p.mu.Lock() defer p.mu.Unlock() return p.readLocked(ctx, ops) } func (p *Pipe) readLocked(ctx context.Context, ops readOps) (int64, error) { + // Don't block for a zero-length read even if the pipe is empty. + if ops.left() == 0 { + return 0, nil + } + // Is the pipe empty? if p.view.Size() == 0 { if !p.HasWriters() { @@ -388,6 +388,10 @@ func (p *Pipe) rwReadiness() waiter.EventMask { func (p *Pipe) queued() int64 { p.mu.Lock() defer p.mu.Unlock() + return p.queuedLocked() +} + +func (p *Pipe) queuedLocked() int64 { return p.view.Size() } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 28f998e45..f61039f5b 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -67,6 +67,11 @@ func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlag return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) } +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (*VFSPipe) Allocate(context.Context, uint64, uint64, uint64) error { + return syserror.ESPIPE +} + // Open opens the pipe represented by vp. func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { vp.mu.Lock() @@ -244,19 +249,57 @@ func (fd *VFSPipeFD) SetPipeSize(size int64) (int64, error) { return fd.pipe.SetFifoSize(size) } -// IOSequence returns a useremm.IOSequence that reads up to count bytes from, -// or writes up to count bytes to, fd. -func (fd *VFSPipeFD) IOSequence(count int64) usermem.IOSequence { - return usermem.IOSequence{ +// SpliceToNonPipe performs a splice operation from fd to a non-pipe file. +func (fd *VFSPipeFD) SpliceToNonPipe(ctx context.Context, out *vfs.FileDescription, off, count int64) (int64, error) { + fd.pipe.mu.Lock() + defer fd.pipe.mu.Unlock() + + // Cap the sequence at number of bytes actually available. + v := fd.pipe.queuedLocked() + if v < count { + count = v + } + src := usermem.IOSequence{ IO: fd, Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), } + + var ( + n int64 + err error + ) + if off == -1 { + n, err = out.Write(ctx, src, vfs.WriteOptions{}) + } else { + n, err = out.PWrite(ctx, src, off, vfs.WriteOptions{}) + } + if n > 0 { + fd.pipe.view.TrimFront(n) + } + return n, err +} + +// SpliceFromNonPipe performs a splice operation from a non-pipe file to fd. +func (fd *VFSPipeFD) SpliceFromNonPipe(ctx context.Context, in *vfs.FileDescription, off, count int64) (int64, error) { + fd.pipe.mu.Lock() + defer fd.pipe.mu.Unlock() + + dst := usermem.IOSequence{ + IO: fd, + Addrs: usermem.AddrRangeSeqOf(usermem.AddrRange{0, usermem.Addr(count)}), + } + + if off == -1 { + return in.Read(ctx, dst, vfs.ReadOptions{}) + } + return in.PRead(ctx, dst, off, vfs.ReadOptions{}) } -// CopyIn implements usermem.IO.CopyIn. +// CopyIn implements usermem.IO.CopyIn. Note that it is the caller's +// responsibility to trim fd.pipe.view after the read is completed. func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, opts usermem.IOOpts) (int, error) { origCount := int64(len(dst)) - n, err := fd.pipe.read(ctx, readOps{ + n, err := fd.pipe.readLocked(ctx, readOps{ left: func() int64 { return int64(len(dst)) }, @@ -265,7 +308,6 @@ func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, }, read: func(view *buffer.View) (int64, error) { n, err := view.ReadAt(dst, 0) - view.TrimFront(int64(n)) return int64(n), err }, }) @@ -281,7 +323,7 @@ func (fd *VFSPipeFD) CopyIn(ctx context.Context, addr usermem.Addr, dst []byte, // CopyOut implements usermem.IO.CopyOut. func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, opts usermem.IOOpts) (int, error) { origCount := int64(len(src)) - n, err := fd.pipe.write(ctx, writeOps{ + n, err := fd.pipe.writeLocked(ctx, writeOps{ left: func() int64 { return int64(len(src)) }, @@ -305,7 +347,7 @@ func (fd *VFSPipeFD) CopyOut(ctx context.Context, addr usermem.Addr, src []byte, // ZeroOut implements usermem.IO.ZeroOut. func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int64, opts usermem.IOOpts) (int64, error) { origCount := toZero - n, err := fd.pipe.write(ctx, writeOps{ + n, err := fd.pipe.writeLocked(ctx, writeOps{ left: func() int64 { return toZero }, @@ -326,14 +368,15 @@ func (fd *VFSPipeFD) ZeroOut(ctx context.Context, addr usermem.Addr, toZero int6 return n, err } -// CopyInTo implements usermem.IO.CopyInTo. +// CopyInTo implements usermem.IO.CopyInTo. Note that it is the caller's +// responsibility to trim fd.pipe.view after the read is completed. func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst safemem.Writer, opts usermem.IOOpts) (int64, error) { count := ars.NumBytes() if count == 0 { return 0, nil } origCount := count - n, err := fd.pipe.read(ctx, readOps{ + n, err := fd.pipe.readLocked(ctx, readOps{ left: func() int64 { return count }, @@ -342,7 +385,6 @@ func (fd *VFSPipeFD) CopyInTo(ctx context.Context, ars usermem.AddrRangeSeq, dst }, read: func(view *buffer.View) (int64, error) { n, err := view.ReadToSafememWriter(dst, uint64(count)) - view.TrimFront(int64(n)) return int64(n), err }, }) @@ -362,7 +404,7 @@ func (fd *VFSPipeFD) CopyOutFrom(ctx context.Context, ars usermem.AddrRangeSeq, return 0, nil } origCount := count - n, err := fd.pipe.write(ctx, writeOps{ + n, err := fd.pipe.writeLocked(ctx, writeOps{ left: func() int64 { return count }, diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 619b0cb7c..1145faf13 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" @@ -224,8 +225,9 @@ func (s *ptraceStop) Killable() bool { // beginPtraceStopLocked does not signal t's tracer or wake it if it is // waiting. // -// Preconditions: The TaskSet mutex must be locked. The caller must be running -// on the task goroutine. +// Preconditions: +// * The TaskSet mutex must be locked. +// * The caller must be running on the task goroutine. func (t *Task) beginPtraceStopLocked() bool { t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() @@ -270,8 +272,9 @@ func (t *Task) ptraceTrapLocked(code int32) { // ptraceStop, temporarily preventing it from being removed by a concurrent // Task.Kill, and returns true. Otherwise it returns false. // -// Preconditions: The TaskSet mutex must be locked. The caller must be running -// on the task goroutine of t's tracer. +// Preconditions: +// * The TaskSet mutex must be locked. +// * The caller must be running on the task goroutine of t's tracer. func (t *Task) ptraceFreeze() bool { t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() @@ -301,8 +304,9 @@ func (t *Task) ptraceUnfreeze() { t.ptraceUnfreezeLocked() } -// Preconditions: t must be in a frozen ptraceStop. t's signal mutex must be -// locked. +// Preconditions: +// * t must be in a frozen ptraceStop. +// * t's signal mutex must be locked. func (t *Task) ptraceUnfreezeLocked() { // Do this even if the task has been killed to ensure a panic if t.stop is // nil or not a ptraceStop. @@ -497,8 +501,9 @@ func (t *Task) forgetTracerLocked() { // ptraceSignalLocked is called after signal dequeueing to check if t should // enter ptrace signal-delivery-stop. // -// Preconditions: The signal mutex must be locked. The caller must be running -// on the task goroutine. +// Preconditions: +// * The signal mutex must be locked. +// * The caller must be running on the task goroutine. func (t *Task) ptraceSignalLocked(info *arch.SignalInfo) bool { if linux.Signal(info.Signo) == linux.SIGKILL { return false @@ -828,8 +833,9 @@ func (t *Task) ptraceInterrupt(target *Task) error { return nil } -// Preconditions: The TaskSet mutex must be locked for writing. t must have a -// tracer. +// Preconditions: +// * The TaskSet mutex must be locked for writing. +// * t must have a tracer. func (t *Task) ptraceSetOptionsLocked(opts uintptr) error { const valid = uintptr(linux.PTRACE_O_EXITKILL | linux.PTRACE_O_TRACESYSGOOD | @@ -994,18 +1000,15 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { // at the address specified by the data parameter, and the return value // is the error flag." - ptrace(2) word := t.Arch().Native(0) - if _, err := usermem.CopyObjectIn(t, target.MemoryManager(), addr, word, usermem.IOOpts{ - IgnorePermissions: true, - }); err != nil { + if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { return err } - _, err := t.CopyOut(data, word) + _, err := word.CopyOut(t, data) return err case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA: - _, err := usermem.CopyObjectOut(t, target.MemoryManager(), addr, t.Arch().Native(uintptr(data)), usermem.IOOpts{ - IgnorePermissions: true, - }) + word := t.Arch().Native(uintptr(data)) + _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr) return err case linux.PTRACE_GETREGSET: @@ -1073,12 +1076,12 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if target.ptraceSiginfo == nil { return syserror.EINVAL } - _, err := t.CopyOut(data, target.ptraceSiginfo) + _, err := target.ptraceSiginfo.CopyOut(t, data) return err case linux.PTRACE_SETSIGINFO: var info arch.SignalInfo - if _, err := t.CopyIn(data, &info); err != nil { + if _, err := info.CopyIn(t, data); err != nil { return err } t.tg.pidns.owner.mu.RLock() @@ -1093,7 +1096,8 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { if addr != linux.SignalSetSize { return syserror.EINVAL } - _, err := t.CopyOut(data, target.SignalMask()) + mask := target.SignalMask() + _, err := mask.CopyOut(t, data) return err case linux.PTRACE_SETSIGMASK: @@ -1101,7 +1105,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { return syserror.EINVAL } var mask linux.SignalSet - if _, err := t.CopyIn(data, &mask); err != nil { + if _, err := mask.CopyIn(t, data); err != nil { return err } // The target's task goroutine is stopped, so this is safe: @@ -1116,7 +1120,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_GETEVENTMSG: t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() - _, err := t.CopyOut(usermem.Addr(data), target.ptraceEventMsg) + _, err := primitive.CopyUint64Out(t, usermem.Addr(data), target.ptraceEventMsg) return err // PEEKSIGINFO is unimplemented but seems to have no users anywhere. diff --git a/pkg/sentry/kernel/ptrace_amd64.go b/pkg/sentry/kernel/ptrace_amd64.go index cef1276ec..609ad3941 100644 --- a/pkg/sentry/kernel/ptrace_amd64.go +++ b/pkg/sentry/kernel/ptrace_amd64.go @@ -30,7 +30,7 @@ func (t *Task) ptraceArch(target *Task, req int64, addr, data usermem.Addr) erro if err != nil { return err } - _, err = t.CopyOut(data, n) + _, err = n.CopyOut(t, data) return err case linux.PTRACE_POKEUSR: // aka PTRACE_POKEUSER diff --git a/pkg/sentry/kernel/rseq.go b/pkg/sentry/kernel/rseq.go index 18416643b..2a9023fdf 100644 --- a/pkg/sentry/kernel/rseq.go +++ b/pkg/sentry/kernel/rseq.go @@ -173,8 +173,10 @@ func (t *Task) OldRSeqCPUAddr() usermem.Addr { // SetOldRSeqCPUAddr replaces the address that old rseq will keep updated with // t's CPU number. // -// Preconditions: t.RSeqAvailable() == true. The caller must be running on the -// task goroutine. t's AddressSpace must be active. +// Preconditions: +// * t.RSeqAvailable() == true. +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error { t.oldRSeqCPUAddr = addr @@ -189,8 +191,9 @@ func (t *Task) SetOldRSeqCPUAddr(addr usermem.Addr) error { return nil } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqUpdateCPU() error { if t.rseqAddr == 0 && t.oldRSeqCPUAddr == 0 { t.rseqCPU = -1 @@ -209,8 +212,9 @@ func (t *Task) rseqUpdateCPU() error { return oerr } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) oldRSeqCopyOutCPU() error { if t.oldRSeqCPUAddr == 0 { return nil @@ -222,8 +226,9 @@ func (t *Task) oldRSeqCopyOutCPU() error { return err } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqCopyOutCPU() error { if t.rseqAddr == 0 { return nil @@ -240,8 +245,9 @@ func (t *Task) rseqCopyOutCPU() error { return err } -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqClearCPU() error { buf := t.CopyScratchBuffer(8) // CPUIDStart and CPUID are the first two fields in linux.RSeq. @@ -269,8 +275,9 @@ func (t *Task) rseqClearCPU() error { // // See kernel/rseq.c:rseq_ip_fixup for reference. // -// Preconditions: The caller must be running on the task goroutine. t's -// AddressSpace must be active. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) rseqAddrInterrupt() { if t.rseqAddr == 0 { return diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index 5c4c622c2..df5c8421b 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -16,8 +16,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/syserror" ) @@ -32,7 +30,7 @@ type ProcessGroupID ThreadID // // +stateify savable type Session struct { - refs refs.AtomicRefCount + SessionRefs // leader is the originator of the Session. // @@ -62,16 +60,11 @@ type Session struct { sessionEntry } -// incRef grabs a reference. -func (s *Session) incRef() { - s.refs.IncRef() -} - -// decRef drops a reference. +// DecRef drops a reference. // // Precondition: callers must hold TaskSet.mu for writing. -func (s *Session) decRef() { - s.refs.DecRefWithDestructor(nil, func(context.Context) { +func (s *Session) DecRef() { + s.SessionRefs.DecRef(func() { // Remove translations from the leader. for ns := s.leader.pidns; ns != nil; ns = ns.parent { id := ns.sids[s] @@ -88,7 +81,7 @@ func (s *Session) decRef() { // // +stateify savable type ProcessGroup struct { - refs refs.AtomicRefCount // not exported. + refs ProcessGroupRefs // originator is the originator of the group. // @@ -163,7 +156,7 @@ func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) { } alive := true - pg.refs.DecRefWithDestructor(nil, func(context.Context) { + pg.refs.DecRef(func() { alive = false // don't bother with handleOrphan. // Remove translations from the originator. @@ -175,7 +168,7 @@ func (pg *ProcessGroup) decRefWithParent(parentPG *ProcessGroup) { // Remove the list of process groups. pg.session.processGroups.Remove(pg) - pg.session.decRef() + pg.session.DecRef() }) if alive { pg.handleOrphan() @@ -302,7 +295,7 @@ func (tg *ThreadGroup) createSession() error { id: SessionID(id), leader: tg, } - s.refs.EnableLeakCheck("kernel.Session") + s.EnableLeakCheck() // Create a new ProcessGroup, belonging to that Session. // This also has a single reference (assigned below). @@ -316,7 +309,7 @@ func (tg *ThreadGroup) createSession() error { session: s, ancestors: 0, } - pg.refs.EnableLeakCheck("kernel.ProcessGroup") + pg.refs.EnableLeakCheck() // Tie them and return the result. s.processGroups.PushBack(pg) @@ -396,13 +389,13 @@ func (tg *ThreadGroup) CreateProcessGroup() error { // // We manually adjust the ancestors if the parent is in the same // session. - tg.processGroup.session.incRef() + tg.processGroup.session.IncRef() pg := ProcessGroup{ id: ProcessGroupID(id), originator: tg, session: tg.processGroup.session, } - pg.refs.EnableLeakCheck("kernel.ProcessGroup") + pg.refs.EnableLeakCheck() if tg.leader.parent != nil && tg.leader.parent.tg.processGroup.session == pg.session { pg.ancestors++ diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index c211fc8d0..b7e4b480d 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -1,12 +1,25 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "shm_refs", + out = "shm_refs.go", + package = "shm", + prefix = "Shm", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Shm", + }, +) + go_library( name = "shm", srcs = [ "device.go", "shm.go", + "shm_refs.go", ], visibility = ["//pkg/sentry:internal"], deps = [ diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go index 13ec7afe0..00c03585e 100644 --- a/pkg/sentry/kernel/shm/shm.go +++ b/pkg/sentry/kernel/shm/shm.go @@ -39,7 +39,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -252,7 +251,7 @@ func (r *Registry) newShm(ctx context.Context, pid int32, key Key, creator fs.Fi creatorPID: pid, changeTime: ktime.NowFromContext(ctx), } - shm.EnableLeakCheck("kernel.Shm") + shm.EnableLeakCheck() // Find the next available ID. for id := r.lastIDUsed + 1; id != r.lastIDUsed; id++ { @@ -337,14 +336,14 @@ func (r *Registry) remove(s *Shm) { // // +stateify savable type Shm struct { - // AtomicRefCount tracks the number of references to this segment. + // ShmRefs tracks the number of references to this segment. // // A segment holds a reference to itself until it is marked for // destruction. // // In addition to direct users, the MemoryManager will hold references // via MappingIdentity. - refs.AtomicRefCount + ShmRefs mfp pgalloc.MemoryFileProvider @@ -428,11 +427,14 @@ func (s *Shm) InodeID() uint64 { return uint64(s.ID) } -// DecRef overrides refs.RefCount.DecRef with a destructor. +// DecRef drops a reference on s. // // Precondition: Caller must not hold s.mu. func (s *Shm) DecRef(ctx context.Context) { - s.DecRefWithDestructor(ctx, s.destroy) + s.ShmRefs.DecRef(func() { + s.mfp.MemoryFile().DecRef(s.fr) + s.registry.remove(s) + }) } // Msync implements memmap.MappingIdentity.Msync. Msync is a no-op for shm @@ -642,11 +644,6 @@ func (s *Shm) Set(ctx context.Context, ds *linux.ShmidDS) error { return nil } -func (s *Shm) destroy(context.Context) { - s.mfp.MemoryFile().DecRef(s.fr) - s.registry.remove(s) -} - // MarkDestroyed marks a segment for destruction. The segment is actually // destroyed once it has no references. MarkDestroyed may be called multiple // times, and is safe to call after a segment has already been destroyed. See diff --git a/pkg/sentry/kernel/syscalls.go b/pkg/sentry/kernel/syscalls.go index 413111faf..332bdb8e8 100644 --- a/pkg/sentry/kernel/syscalls.go +++ b/pkg/sentry/kernel/syscalls.go @@ -348,6 +348,16 @@ func (s *SyscallTable) LookupName(sysno uintptr) string { return fmt.Sprintf("sys_%d", sysno) // Unlikely. } +// LookupNo looks up a syscall number by name. +func (s *SyscallTable) LookupNo(name string) (uintptr, error) { + for i, syscall := range s.Table { + if syscall.Name == name { + return uintptr(i), nil + } + } + return 0, fmt.Errorf("syscall %q not found", name) +} + // LookupEmulate looks up an emulation syscall number. func (s *SyscallTable) LookupEmulate(addr usermem.Addr) (uintptr, bool) { sysno, ok := s.Emulate[addr] diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 5aee699e7..a436610c9 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -574,6 +574,11 @@ type Task struct { // // startTime is protected by mu. startTime ktime.Time + + // kcov is the kcov instance providing code coverage owned by this task. + // + // kcov is exclusive to the task goroutine. + kcov *Kcov } func (t *Task) savePtraceTracer() *Task { @@ -903,3 +908,16 @@ func (t *Task) UID() uint32 { func (t *Task) GID() uint32 { return uint32(t.Credentials().EffectiveKGID) } + +// SetKcov sets the kcov instance associated with t. +func (t *Task) SetKcov(k *Kcov) { + t.kcov = k +} + +// ResetKcov clears the kcov instance associated with t. +func (t *Task) ResetKcov() { + if t.kcov != nil { + t.kcov.Reset() + t.kcov = nil + } +} diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 9d7a9128f..fce1064a7 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -341,12 +341,12 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { nt.SetClearTID(opts.ChildTID) } if opts.ChildSetTID { - // Can't use Task.CopyOut, which assumes AddressSpaceActive. - usermem.CopyObjectOut(t, nt.MemoryManager(), opts.ChildTID, nt.ThreadID(), usermem.IOOpts{}) + ctid := nt.ThreadID() + ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) } ntid := t.tg.pidns.IDOfTask(nt) if opts.ParentSetTID { - t.CopyOut(opts.ParentTID, ntid) + ntid.CopyOut(t, opts.ParentTID) } kind := ptraceCloneKindClone diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index 5e4fb3e3a..412d471d3 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -237,9 +237,10 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { // promoteLocked makes t the leader of its thread group. If t is already the // thread group leader, promoteLocked is a no-op. // -// Preconditions: All other tasks in t's thread group, including the existing -// leader (if it is not t), have reached TaskExitZombie. The TaskSet mutex must -// be locked for writing. +// Preconditions: +// * All other tasks in t's thread group, including the existing leader (if it +// is not t), have reached TaskExitZombie. +// * The TaskSet mutex must be locked for writing. func (t *Task) promoteLocked() { oldLeader := t.tg.leader if t == oldLeader { diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c165d6cb1..b400a8b41 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -239,6 +239,8 @@ func (*runExitMain) execute(t *Task) taskRunState { t.traceExitEvent() lastExiter := t.exitThreadGroup() + t.ResetKcov() + // If the task has a cleartid, and the thread group wasn't killed by a // signal, handle that before releasing the MM. if t.cleartid != 0 { @@ -246,7 +248,8 @@ func (*runExitMain) execute(t *Task) taskRunState { signaled := t.tg.exiting && t.tg.exitStatus.Signaled() t.tg.signalHandlers.mu.Unlock() if !signaled { - if _, err := t.CopyOut(t.cleartid, ThreadID(0)); err == nil { + zero := ThreadID(0) + if _, err := zero.CopyOut(t, t.cleartid); err == nil { t.Futex().Wake(t, t.cleartid, false, ^uint32(0), 1) } // If the CopyOut fails, there's nothing we can do. diff --git a/pkg/sentry/kernel/task_futex.go b/pkg/sentry/kernel/task_futex.go index 4b535c949..c80391475 100644 --- a/pkg/sentry/kernel/task_futex.go +++ b/pkg/sentry/kernel/task_futex.go @@ -16,6 +16,7 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/usermem" ) @@ -87,7 +88,7 @@ func (t *Task) exitRobustList() { return } - next := rl.List + next := primitive.Uint64(rl.List) done := 0 var pendingLockAddr usermem.Addr if rl.ListOpPending != 0 { @@ -99,12 +100,12 @@ func (t *Task) exitRobustList() { // We traverse to the next element of the list before we // actually wake anything. This prevents the race where waking // this futex causes a modification of the list. - thisLockAddr := usermem.Addr(next + rl.FutexOffset) + thisLockAddr := usermem.Addr(uint64(next) + rl.FutexOffset) // Try to decode the next element in the list before waking the // current futex. But don't check the error until after we've // woken the current futex. Linux does it in this order too - _, nextErr := t.CopyIn(usermem.Addr(next), &next) + _, nextErr := next.CopyIn(t, usermem.Addr(next)) // Wakeup the current futex if it's not pending. if thisLockAddr != pendingLockAddr { diff --git a/pkg/sentry/kernel/task_run.go b/pkg/sentry/kernel/task_run.go index abaf29216..8dc3fec90 100644 --- a/pkg/sentry/kernel/task_run.go +++ b/pkg/sentry/kernel/task_run.go @@ -26,6 +26,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/platform" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -140,7 +141,7 @@ func (*runApp) handleCPUIDInstruction(t *Task) error { region := trace.StartRegion(t.traceContext, cpuidRegion) expected := arch.CPUIDInstruction[:] found := make([]byte, len(expected)) - _, err := t.CopyIn(usermem.Addr(t.Arch().IP()), &found) + _, err := t.CopyInBytes(usermem.Addr(t.Arch().IP()), found) if err == nil && bytes.Equal(expected, found) { // Skip the cpuid instruction. t.Arch().CPUIDEmulate(t) @@ -189,8 +190,8 @@ func (app *runApp) execute(t *Task) taskRunState { // a pending signal, causing another interruption, but that signal should // not interact with the interrupted syscall.) if t.haveSyscallReturn { - if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { - if sre == ERESTART_RESTARTBLOCK { + if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { + if sre == syserror.ERESTART_RESTARTBLOCK { t.Debugf("Restarting syscall %d with restart block after errno %d: not interrupted by handled signal", t.Arch().SyscallNo(), sre) t.Arch().RestartSyscallWithRestartBlock() } else { diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 09366b60c..52c55d13d 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -133,9 +133,10 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) { } } -// Preconditions: The caller must be running on the task goroutine, and leaving -// a state indicated by a previous call to -// t.accountTaskGoroutineEnter(state). +// Preconditions: +// * The caller must be running on the task goroutine +// * The caller must be leaving a state indicated by a previous call to +// t.accountTaskGoroutineEnter(state). func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) { if state != TaskGoroutineRunningApp { // Task is unblocking/continuing. @@ -191,8 +192,8 @@ func (tg *ThreadGroup) CPUStats() usage.CPUStats { return tg.cpuStatsAtLocked(tg.leader.k.CPUClockNow()) } -// Preconditions: As for TaskGoroutineSchedInfo.userTicksAt. The TaskSet mutex -// must be locked. +// Preconditions: Same as TaskGoroutineSchedInfo.userTicksAt, plus: +// * The TaskSet mutex must be locked. func (tg *ThreadGroup) cpuStatsAtLocked(now uint64) usage.CPUStats { stats := tg.exitedCPUStats // Account for live tasks. diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index cff2a8365..feaa38596 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -159,7 +159,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS sigact := computeAction(linux.Signal(info.Signo), act) if t.haveSyscallReturn { - if sre, ok := SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { + if sre, ok := syserror.SyscallRestartErrnoFromReturn(t.Arch().Return()); ok { // Signals that are ignored, cause a thread group stop, or // terminate the thread group do not interact with interrupted // syscalls; in Linux terms, they are never returned to the signal @@ -168,11 +168,11 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS // signal that is actually handled (by userspace). if sigact == SignalActionHandler { switch { - case sre == ERESTARTNOHAND: + case sre == syserror.ERESTARTNOHAND: fallthrough - case sre == ERESTART_RESTARTBLOCK: + case sre == syserror.ERESTART_RESTARTBLOCK: fallthrough - case (sre == ERESTARTSYS && !act.IsRestart()): + case (sre == syserror.ERESTARTSYS && !act.IsRestart()): t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1))) default: @@ -319,8 +319,9 @@ func (t *Task) SignalReturn(rt bool) (*SyscallControl, error) { // Sigtimedwait implements the semantics of sigtimedwait(2). // -// Preconditions: The caller must be running on the task goroutine. t.exitState -// < TaskExitZombie. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t.exitState < TaskExitZombie. func (t *Task) Sigtimedwait(set linux.SignalSet, timeout time.Duration) (*arch.SignalInfo, error) { // set is the set of signals we're interested in; invert it to get the set // of signals to block. @@ -584,8 +585,9 @@ func (t *Task) SignalMask() linux.SignalSet { // SetSignalMask sets t's signal mask. // -// Preconditions: SetSignalMask can only be called by the task goroutine. -// t.exitState < TaskExitZombie. +// Preconditions: +// * The caller must be running on the task goroutine. +// * t.exitState < TaskExitZombie. func (t *Task) SetSignalMask(mask linux.SignalSet) { // By precondition, t prevents t.tg from completing an execve and mutating // t.tg.signalHandlers, so we can skip the TaskSet mutex. @@ -631,7 +633,7 @@ func (t *Task) setSignalMaskLocked(mask linux.SignalSet) { // SetSavedSignalMask sets the saved signal mask (see Task.savedSignalMask's // comment). // -// Preconditions: SetSavedSignalMask can only be called by the task goroutine. +// Preconditions: The caller must be running on the task goroutine. func (t *Task) SetSavedSignalMask(mask linux.SignalSet) { t.savedSignalMask = mask t.haveSavedSignalMask = true diff --git a/pkg/sentry/kernel/task_stop.go b/pkg/sentry/kernel/task_stop.go index 296735d32..a35948a5f 100644 --- a/pkg/sentry/kernel/task_stop.go +++ b/pkg/sentry/kernel/task_stop.go @@ -99,8 +99,9 @@ type TaskStop interface { // beginInternalStop indicates the start of an internal stop that applies to t. // -// Preconditions: The task must not already be in an internal stop (i.e. t.stop -// == nil). The caller must be running on the task goroutine. +// Preconditions: +// * The caller must be running on the task goroutine. +// * The task must not already be in an internal stop (i.e. t.stop == nil). func (t *Task) beginInternalStop(s TaskStop) { t.tg.pidns.owner.mu.RLock() defer t.tg.pidns.owner.mu.RUnlock() @@ -109,8 +110,8 @@ func (t *Task) beginInternalStop(s TaskStop) { t.beginInternalStopLocked(s) } -// Preconditions: The signal mutex must be locked. All preconditions for -// Task.beginInternalStop also apply. +// Preconditions: Same as beginInternalStop, plus: +// * The signal mutex must be locked. func (t *Task) beginInternalStopLocked(s TaskStop) { if t.stop != nil { panic(fmt.Sprintf("Attempting to enter internal stop %#v when already in internal stop %#v", s, t.stop)) @@ -128,8 +129,9 @@ func (t *Task) beginInternalStopLocked(s TaskStop) { // t.stop, which is why there is no endInternalStop that locks the signal mutex // for you. // -// Preconditions: The signal mutex must be locked. The task must be in an -// internal stop (i.e. t.stop != nil). +// Preconditions: +// * The signal mutex must be locked. +// * The task must be in an internal stop (i.e. t.stop != nil). func (t *Task) endInternalStopLocked() { if t.stop == nil { panic("Attempting to leave non-existent internal stop") diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index a5903b0b5..0141459e7 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -29,75 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel -// include/linux/errno.h. These errnos are never returned to userspace -// directly, but are used to communicate the expected behavior of an -// interrupted syscall from the syscall to signal handling. -type SyscallRestartErrno int - -// These numeric values are significant because ptrace syscall exit tracing can -// observe them. -// -// For all of the following errnos, if the syscall is not interrupted by a -// signal delivered to a user handler, the syscall is restarted. -const ( - // ERESTARTSYS is returned by an interrupted syscall to indicate that it - // should be converted to EINTR if interrupted by a signal delivered to a - // user handler without SA_RESTART set, and restarted otherwise. - ERESTARTSYS = SyscallRestartErrno(512) - - // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it - // should always be restarted. - ERESTARTNOINTR = SyscallRestartErrno(513) - - // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it - // should be converted to EINTR if interrupted by a signal delivered to a - // user handler, and restarted otherwise. - ERESTARTNOHAND = SyscallRestartErrno(514) - - // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate - // that it should be restarted using a custom function. The interrupted - // syscall must register a custom restart function by calling - // Task.SetRestartSyscallFn. - ERESTART_RESTARTBLOCK = SyscallRestartErrno(516) -) - var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application") -// Error implements error.Error. -func (e SyscallRestartErrno) Error() string { - // Descriptions are borrowed from strace. - switch e { - case ERESTARTSYS: - return "to be restarted if SA_RESTART is set" - case ERESTARTNOINTR: - return "to be restarted" - case ERESTARTNOHAND: - return "to be restarted if no handler" - case ERESTART_RESTARTBLOCK: - return "interrupted by signal" - default: - return "(unknown interrupt error)" - } -} - -// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by -// rv, the value in a syscall return register. -func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) { - switch int(rv) { - case -int(ERESTARTSYS): - return ERESTARTSYS, true - case -int(ERESTARTNOINTR): - return ERESTARTNOINTR, true - case -int(ERESTARTNOHAND): - return ERESTARTNOHAND, true - case -int(ERESTART_RESTARTBLOCK): - return ERESTART_RESTARTBLOCK, true - default: - return 0, false - } -} - // SyscallRestartBlock represents the restart block for a syscall restartable // with a custom function. It encapsulates the state required to restart a // syscall across a S/R. @@ -354,7 +288,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { // Grab the caller up front, to make sure there's a sensible stack. caller := t.Arch().Native(uintptr(0)) - if _, err := t.CopyIn(usermem.Addr(t.Arch().Stack()), caller); err != nil { + if _, err := caller.CopyIn(t, usermem.Addr(t.Arch().Stack())); err != nil { t.Debugf("vsyscall %d: error reading return address from stack: %v", sysno, err) t.forceSignal(linux.SIGSEGV, false /* unconditional */) t.SendSignal(SignalInfoPriv(linux.SIGSEGV)) @@ -390,7 +324,7 @@ func (t *Task) doVsyscall(addr usermem.Addr, sysno uintptr) taskRunState { type runVsyscallAfterPtraceEventSeccomp struct { addr usermem.Addr sysno uintptr - caller interface{} + caller marshal.Marshallable } func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState { @@ -413,7 +347,7 @@ func (r *runVsyscallAfterPtraceEventSeccomp) execute(t *Task) taskRunState { return t.doVsyscallInvoke(sysno, t.Arch().SyscallArgs(), r.caller) } -func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller interface{}) taskRunState { +func (t *Task) doVsyscallInvoke(sysno uintptr, args arch.SyscallArguments, caller marshal.Marshallable) taskRunState { rval, ctrl, err := t.executeSyscall(sysno, args) if ctrl != nil { t.Debugf("vsyscall %d, caller %x: syscall control: %v", sysno, t.Arch().Value(caller), ctrl) @@ -447,7 +381,7 @@ func ExtractErrno(err error, sysno int) int { return 0 case syscall.Errno: return int(err) - case SyscallRestartErrno: + case syserror.SyscallRestartErrno: return int(err) case *memmap.BusError: // Bus errors may generate SIGBUS, but for syscalls they still diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index b02044ad2..ce134bf54 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -43,17 +44,6 @@ func (t *Task) Deactivate() { } } -// CopyIn copies a fixed-size value or slice of fixed-size values in from the -// task's memory. The copy will fail with syscall.EFAULT if it traverses user -// memory that is unmapped or not readable by the user. -// -// This Task's AddressSpace must be active. -func (t *Task) CopyIn(addr usermem.Addr, dst interface{}) (int, error) { - return usermem.CopyObjectIn(t, t.MemoryManager(), addr, dst, usermem.IOOpts{ - AddressSpaceActive: true, - }) -} - // CopyInBytes is a fast version of CopyIn if the caller can serialize the // data without reflection and pass in a byte slice. // @@ -64,17 +54,6 @@ func (t *Task) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { }) } -// CopyOut copies a fixed-size value or slice of fixed-size values out to the -// task's memory. The copy will fail with syscall.EFAULT if it traverses user -// memory that is unmapped or not writeable by the user. -// -// This Task's AddressSpace must be active. -func (t *Task) CopyOut(addr usermem.Addr, src interface{}) (int, error) { - return usermem.CopyObjectOut(t, t.MemoryManager(), addr, src, usermem.IOOpts{ - AddressSpaceActive: true, - }) -} - // CopyOutBytes is a fast version of CopyOut if the caller can serialize the // data without reflection and pass in a byte slice. // @@ -114,7 +93,7 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([ var v []string for { argAddr := t.Arch().Native(0) - if _, err := t.CopyIn(addr, argAddr); err != nil { + if _, err := argAddr.CopyIn(t, addr); err != nil { return v, err } if t.Arch().Value(argAddr) == 0 { @@ -143,8 +122,9 @@ func (t *Task) CopyInVector(addr usermem.Addr, maxElemSize, maxTotalSize int) ([ // CopyOutIovecs converts src to an array of struct iovecs and copies it to the // memory mapped at addr. // -// Preconditions: As for usermem.IO.CopyOut. The caller must be running on the -// task goroutine. t's AddressSpace must be active. +// Preconditions: Same as usermem.IO.CopyOut, plus: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error { switch t.Arch().Width() { case 8: @@ -191,8 +171,9 @@ func (t *Task) CopyOutIovecs(addr usermem.Addr, src usermem.AddrRangeSeq) error // combined length of all AddrRanges would otherwise exceed this amount, ranges // beyond MAX_RW_COUNT are silently truncated. // -// Preconditions: As for usermem.IO.CopyIn. The caller must be running on the -// task goroutine. t's AddressSpace must be active. +// Preconditions: Same as usermem.IO.CopyIn, plus: +// * The caller must be running on the task goroutine. +// * t's AddressSpace must be active. func (t *Task) CopyInIovecs(addr usermem.Addr, numIovecs int) (usermem.AddrRangeSeq, error) { if numIovecs == 0 { return usermem.AddrRangeSeq{}, nil @@ -284,7 +265,7 @@ func (t *Task) SingleIOSequence(addr usermem.Addr, length int, opts usermem.IOOp // // IovecsIOSequence is analogous to Linux's lib/iov_iter.c:import_iovec(). // -// Preconditions: As for Task.CopyInIovecs. +// Preconditions: Same as Task.CopyInIovecs. func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOpts) (usermem.IOSequence, error) { if iovcnt < 0 || iovcnt > linux.UIO_MAXIOV { return usermem.IOSequence{}, syserror.EINVAL @@ -299,3 +280,30 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp Opts: opts, }, nil } + +// copyContext implements marshal.CopyContext. It wraps a task to allow copying +// memory to and from the task memory with custom usermem.IOOpts. +type copyContext struct { + *Task + opts usermem.IOOpts +} + +// AsCopyContext wraps the task and returns it as CopyContext. +func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext { + return ©Context{t, opts} +} + +// CopyInString copies a string in from the task's memory. +func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) { + return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts) +} + +// CopyInBytes copies task memory into dst from an IO context. +func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + return t.MemoryManager().CopyIn(t, addr, dst, t.opts) +} + +// CopyOutBytes copies src into task memoryfrom an IO context. +func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + return t.MemoryManager().CopyOut(t, addr, src, t.opts) +} diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 872e1a82d..5ae5906e8 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -36,6 +36,8 @@ import ( const TasksLimit = (1 << 16) // ThreadID is a generic thread identifier. +// +// +marshal type ThreadID int32 // String returns a decimal representation of the ThreadID. diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go index e959700f2..f61a8e164 100644 --- a/pkg/sentry/kernel/time/time.go +++ b/pkg/sentry/kernel/time/time.go @@ -616,8 +616,10 @@ func (t *Timer) Swap(s Setting) (Time, Setting) { // Timer's Clock) at which the Setting was changed. Setting s.Enabled to true // starts the timer, while setting s.Enabled to false stops it. // -// Preconditions: The Timer must not be paused. f cannot call any Timer methods -// since it is called with the Timer mutex locked. +// Preconditions: +// * The Timer must not be paused. +// * f cannot call any Timer methods since it is called with the Timer mutex +// locked. func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) { now := t.clock.Now() t.mu.Lock() diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index 290c32466..e44a139b3 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -73,13 +73,10 @@ type VDSOParamPage struct { // NewVDSOParamPage returns a VDSOParamPage. // // Preconditions: -// // * fr is a single page allocated from mfp.MemoryFile(). VDSOParamPage does // not take ownership of fr; it must remain allocated for the lifetime of the // VDSOParamPage. -// // * VDSOParamPage must be the only writer to fr. -// // * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block. func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage { return &VDSOParamPage{mfp: mfp, fr: fr} diff --git a/pkg/sentry/limits/context.go b/pkg/sentry/limits/context.go index 77e1fe217..0bade6e57 100644 --- a/pkg/sentry/limits/context.go +++ b/pkg/sentry/limits/context.go @@ -33,3 +33,12 @@ func FromContext(ctx context.Context) *LimitSet { } return nil } + +// FromContextOrDie returns FromContext(ctx) if the latter is not nil. +// Otherwise, panic is triggered. +func FromContextOrDie(ctx context.Context) *LimitSet { + if v := ctx.Value(CtxLimits); v != nil { + return v.(*LimitSet) + } + panic("failed to create limit set from context") +} diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index 20dd1cc21..d4610ec3b 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -402,8 +402,7 @@ type loadedELF struct { // // It does not load the ELF interpreter, or return any auxv entries. // -// Preconditions: -// * f is an ELF file +// Preconditions: f is an ELF file. func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, info elfInfo, sharedLoadOffset usermem.Addr) (loadedELF, error) { first := true var start, end usermem.Addr @@ -571,8 +570,8 @@ func loadParsedELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, in // It does not load the ELF interpreter, or return any auxv entries. // // Preconditions: -// * f is an ELF file -// * f is the first ELF loaded into m +// * f is an ELF file. +// * f is the first ELF loaded into m. func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureSet, f fsbridge.File) (loadedELF, arch.Context, error) { info, err := parseHeader(ctx, f) if err != nil { @@ -609,8 +608,7 @@ func loadInitialELF(ctx context.Context, m *mm.MemoryManager, fs *cpuid.FeatureS // // It does not return any auxv entries. // -// Preconditions: -// * f is an ELF file +// Preconditions: f is an ELF file. func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.File, initial loadedELF) (loadedELF, error) { info, err := parseHeader(ctx, f) if err != nil { @@ -640,8 +638,7 @@ func loadInterpreterELF(ctx context.Context, m *mm.MemoryManager, f fsbridge.Fil // If loadELF returns ErrSwitchFile it should be called again with the returned // path and argv. // -// Preconditions: -// * args.File is an ELF file +// Preconditions: args.File is an ELF file. func loadELF(ctx context.Context, args LoadArgs) (loadedELF, arch.Context, error) { bin, ac, err := loadInitialELF(ctx, args.MemoryManager, args.Features, args.File) if err != nil { diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go index 8d6802ea3..15c88aa7c 100644 --- a/pkg/sentry/loader/loader.go +++ b/pkg/sentry/loader/loader.go @@ -215,8 +215,8 @@ func loadExecutable(ctx context.Context, args LoadArgs) (loadedELF, arch.Context // path and argv. // // Preconditions: -// * The Task MemoryManager is empty. -// * Load is called on the Task goroutine. +// * The Task MemoryManager is empty. +// * Load is called on the Task goroutine. func Load(ctx context.Context, args LoadArgs, extraAuxv []arch.AuxEntry, vdso *VDSO) (abi.OS, arch.Context, string, *syserr.Error) { // Load the executable itself. loaded, ac, file, newArgv, err := loadExecutable(ctx, args) diff --git a/pkg/sentry/memmap/mapping_set.go b/pkg/sentry/memmap/mapping_set.go index d609c1ae0..457ed87f8 100644 --- a/pkg/sentry/memmap/mapping_set.go +++ b/pkg/sentry/memmap/mapping_set.go @@ -177,7 +177,7 @@ func subsetMapping(wholeRange, subsetRange MappableRange, ms MappingSpace, addr // AddMapping adds the given mapping and returns the set of MappableRanges that // previously had no mappings. // -// Preconditions: As for Mappable.AddMapping. +// Preconditions: Same as Mappable.AddMapping. func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange { mr := MappableRange{offset, offset + uint64(ar.Length())} var mapped []MappableRange @@ -204,7 +204,7 @@ func (s *MappingSet) AddMapping(ms MappingSpace, ar usermem.AddrRange, offset ui // RemoveMapping removes the given mapping and returns the set of // MappableRanges that now have no mappings. // -// Preconditions: As for Mappable.RemoveMapping. +// Preconditions: Same as Mappable.RemoveMapping. func (s *MappingSet) RemoveMapping(ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) []MappableRange { mr := MappableRange{offset, offset + uint64(ar.Length())} var unmapped []MappableRange diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 65d83096f..a44fa2b95 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -28,9 +28,9 @@ import ( // // See mm/mm.go for Mappable's place in the lock order. // -// Preconditions: For all Mappable methods, usermem.AddrRanges and -// MappableRanges must be non-empty (Length() != 0), and usermem.Addrs and -// Mappable offsets must be page-aligned. +// All Mappable methods have the following preconditions: +// * usermem.AddrRanges and MappableRanges must be non-empty (Length() != 0). +// * usermem.Addrs and Mappable offsets must be page-aligned. type Mappable interface { // AddMapping notifies the Mappable of a mapping from addresses ar in ms to // offsets [offset, offset+ar.Length()) in this Mappable. @@ -48,8 +48,10 @@ type Mappable interface { // addresses ar in ms to offsets [offset, offset+ar.Length()) in this // Mappable. // - // Preconditions: offset+ar.Length() does not overflow. The removed mapping - // must exist. writable must match the corresponding call to AddMapping. + // Preconditions: + // * offset+ar.Length() does not overflow. + // * The removed mapping must exist. writable must match the + // corresponding call to AddMapping. RemoveMapping(ctx context.Context, ms MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) // CopyMapping notifies the Mappable of an attempt to copy a mapping in ms @@ -60,9 +62,10 @@ type Mappable interface { // CopyMapping is only called when a mapping is copied within a given // MappingSpace; it is analogous to Linux's vm_operations_struct::mremap. // - // Preconditions: offset+srcAR.Length() and offset+dstAR.Length() do not - // overflow. The mapping at srcAR must exist. writable must match the - // corresponding call to AddMapping. + // Preconditions: + // * offset+srcAR.Length() and offset+dstAR.Length() do not overflow. + // * The mapping at srcAR must exist. writable must match the + // corresponding call to AddMapping. CopyMapping(ctx context.Context, ms MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error // Translate returns the Mappable's current mappings for at least the range @@ -77,11 +80,14 @@ type Mappable interface { // reference is held on all pages in a File that may be the result // of a valid Translation. // - // Preconditions: required.Length() > 0. optional.IsSupersetOf(required). - // required and optional must be page-aligned. The caller must have - // established a mapping for all of the queried offsets via a previous call - // to AddMapping. The caller is responsible for ensuring that calls to - // Translate synchronize with invalidation. + // Preconditions: + // * required.Length() > 0. + // * optional.IsSupersetOf(required). + // * required and optional must be page-aligned. + // * The caller must have established a mapping for all of the queried + // offsets via a previous call to AddMapping. + // * The caller is responsible for ensuring that calls to Translate + // synchronize with invalidation. // // Postconditions: See CheckTranslateResult. Translate(ctx context.Context, required, optional MappableRange, at usermem.AccessType) ([]Translation, error) @@ -118,7 +124,7 @@ func (t Translation) FileRange() FileRange { // CheckTranslateResult returns an error if (ts, terr) does not satisfy all // postconditions for Mappable.Translate(required, optional, at). // -// Preconditions: As for Mappable.Translate. +// Preconditions: Same as Mappable.Translate. func CheckTranslateResult(required, optional MappableRange, at usermem.AccessType, ts []Translation, terr error) error { // Verify that the inputs to Mappable.Translate were valid. if !required.WellFormed() || required.Length() <= 0 { @@ -214,7 +220,9 @@ type MappingSpace interface { // Invalidate must not take any locks preceding mm.MemoryManager.activeMu // in the lock order. // - // Preconditions: ar.Length() != 0. ar must be page-aligned. + // Preconditions: + // * ar.Length() != 0. + // * ar must be page-aligned. Invalidate(ar usermem.AddrRange, opts InvalidateOpts) } @@ -375,16 +383,20 @@ type File interface { // IncRef increments the reference count on all pages in fr. // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. (The File - // interface does not provide a way to acquire an initial reference; - // implementors may define mechanisms for doing so.) + // Preconditions: + // * fr.Start and fr.End must be page-aligned. + // * fr.Length() > 0. + // * At least one reference must be held on all pages in fr. (The File + // interface does not provide a way to acquire an initial reference; + // implementors may define mechanisms for doing so.) IncRef(fr FileRange) // DecRef decrements the reference count on all pages in fr. // - // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() > - // 0. At least one reference must be held on all pages in fr. + // Preconditions: + // * fr.Start and fr.End must be page-aligned. + // * fr.Length() > 0. + // * At least one reference must be held on all pages in fr. DecRef(fr FileRange) // MapInternal returns a mapping of the given file offsets in the invoking @@ -392,8 +404,9 @@ type File interface { // // Note that fr.Start and fr.End need not be page-aligned. // - // Preconditions: fr.Length() > 0. At least one reference must be held on - // all pages in fr. + // Preconditions: + // * fr.Length() > 0. + // * At least one reference must be held on all pages in fr. // // Postconditions: The returned mapping is valid as long as at least one // reference is held on the mapped pages. diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index f9d0837a1..b4a47ccca 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -73,12 +73,35 @@ go_template_instance( }, ) +go_template_instance( + name = "aio_mappable_refs", + out = "aio_mappable_refs.go", + package = "mm", + prefix = "aioMappable", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "aioMappable", + }, +) + +go_template_instance( + name = "special_mappable_refs", + out = "special_mappable_refs.go", + package = "mm", + prefix = "SpecialMappable", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "SpecialMappable", + }, +) + go_library( name = "mm", srcs = [ "address_space.go", "aio_context.go", "aio_context_state.go", + "aio_mappable_refs.go", "debug.go", "file_refcount_set.go", "io.go", @@ -92,6 +115,7 @@ go_library( "save_restore.go", "shm.go", "special_mappable.go", + "special_mappable_refs.go", "syscalls.go", "vma.go", "vma_set.go", diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go index 5c667117c..a93e76c75 100644 --- a/pkg/sentry/mm/address_space.go +++ b/pkg/sentry/mm/address_space.go @@ -166,8 +166,12 @@ func (mm *MemoryManager) Deactivate() { // mapASLocked maps addresses in ar into mm.as. If precommit is true, mappings // for all addresses in ar should be precommitted. // -// Preconditions: mm.activeMu must be locked. mm.as != nil. ar.Length() != 0. -// ar must be page-aligned. pseg == mm.pmas.LowerBoundSegment(ar.Start). +// Preconditions: +// * mm.activeMu must be locked. +// * mm.as != nil. +// * ar.Length() != 0. +// * ar must be page-aligned. +// * pseg == mm.pmas.LowerBoundSegment(ar.Start). func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar usermem.AddrRange, precommit bool) error { // By default, map entire pmas at a time, under the assumption that there // is no cost to mapping more of a pma than necessary. diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 16fea53c4..7bf48cb2c 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -17,7 +17,6 @@ package mm import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" @@ -239,7 +238,7 @@ func (ctx *AIOContext) Drain() { // // +stateify savable type aioMappable struct { - refs.AtomicRefCount + aioMappableRefs mfp pgalloc.MemoryFileProvider fr memmap.FileRange @@ -253,13 +252,13 @@ func newAIOMappable(mfp pgalloc.MemoryFileProvider) (*aioMappable, error) { return nil, err } m := aioMappable{mfp: mfp, fr: fr} - m.EnableLeakCheck("mm.aioMappable") + m.EnableLeakCheck() return &m, nil } // DecRef implements refs.RefCounter.DecRef. func (m *aioMappable) DecRef(ctx context.Context) { - m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) { + m.aioMappableRefs.DecRef(func() { m.mfp.MemoryFile().DecRef(m.fr) }) } diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go index fa776f9c6..a8ac48080 100644 --- a/pkg/sentry/mm/io.go +++ b/pkg/sentry/mm/io.go @@ -441,7 +441,10 @@ func (mm *MemoryManager) LoadUint32(ctx context.Context, addr usermem.Addr, opts // handleASIOFault handles a page fault at address addr for an AddressSpaceIO // operation spanning ioar. // -// Preconditions: mm.as != nil. ioar.Length() != 0. ioar.Contains(addr). +// Preconditions: +// * mm.as != nil. +// * ioar.Length() != 0. +// * ioar.Contains(addr). func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr usermem.Addr, ioar usermem.AddrRange, at usermem.AccessType) error { // Try to map all remaining pages in the I/O operation. This RoundUp can't // overflow because otherwise it would have been caught by CheckIORange. @@ -629,7 +632,9 @@ func (mm *MemoryManager) withVecInternalMappings(ctx context.Context, ars userme // at most address end on AddrRange arsit.Head(). It is used in vector I/O paths to // truncate usermem.AddrRangeSeq when errors occur. // -// Preconditions: !arsit.IsEmpty(). end <= arsit.Head().End. +// Preconditions: +// * !arsit.IsEmpty(). +// * end <= arsit.Head().End. func truncatedAddrRangeSeq(ars, arsit usermem.AddrRangeSeq, end usermem.Addr) usermem.AddrRangeSeq { ar := arsit.Head() if end <= ar.Start { diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 3e85964e4..8c9f11cce 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -242,7 +242,7 @@ type MemoryManager struct { // +stateify savable type vma struct { // mappable is the virtual memory object mapped by this vma. If mappable is - // nil, the vma represents a private anonymous mapping. + // nil, the vma represents an anonymous mapping. mappable memmap.Mappable // off is the offset into mappable at which this vma begins. If mappable is diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index fdc308542..acac3d357 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -51,7 +51,8 @@ func TestUsageASUpdates(t *testing.T) { defer mm.DecUsers(ctx) addr, err := mm.MMap(ctx, memmap.MMapOpts{ - Length: 2 * usermem.PageSize, + Length: 2 * usermem.PageSize, + Private: true, }) if err != nil { t.Fatalf("MMap got err %v want nil", err) diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go index 930ec895f..30facebf7 100644 --- a/pkg/sentry/mm/pma.go +++ b/pkg/sentry/mm/pma.go @@ -31,7 +31,9 @@ import ( // iterator to the pma containing ar.Start. Otherwise it returns a terminal // iterator. // -// Preconditions: mm.activeMu must be locked. ar.Length() != 0. +// Preconditions: +// * mm.activeMu must be locked. +// * ar.Length() != 0. func (mm *MemoryManager) existingPMAsLocked(ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool, needInternalMappings bool) pmaIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -89,10 +91,13 @@ func (mm *MemoryManager) existingVecPMAsLocked(ars usermem.AddrRangeSeq, at user // // - An error that is non-nil if pmas exist for only a subset of ar. // -// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for -// writing. ar.Length() != 0. vseg.Range().Contains(ar.Start). vmas must exist -// for all addresses in ar, and support accesses of type at (i.e. permission -// checks must have been performed against vmas). +// Preconditions: +// * mm.mappingMu must be locked. +// * mm.activeMu must be locked for writing. +// * ar.Length() != 0. +// * vseg.Range().Contains(ar.Start). +// * vmas must exist for all addresses in ar, and support accesses of type at +// (i.e. permission checks must have been performed against vmas). func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, at usermem.AccessType) (pmaIterator, pmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -135,9 +140,11 @@ func (mm *MemoryManager) getPMAsLocked(ctx context.Context, vseg vmaIterator, ar // exist. If this is not equal to ars, it returns a non-nil error explaining // why. // -// Preconditions: mm.mappingMu must be locked. mm.activeMu must be locked for -// writing. vmas must exist for all addresses in ars, and support accesses of -// type at (i.e. permission checks must have been performed against vmas). +// Preconditions: +// * mm.mappingMu must be locked. +// * mm.activeMu must be locked for writing. +// * vmas must exist for all addresses in ars, and support accesses of type at +// (i.e. permission checks must have been performed against vmas). func (mm *MemoryManager) getVecPMAsLocked(ctx context.Context, ars usermem.AddrRangeSeq, at usermem.AccessType) (usermem.AddrRangeSeq, error) { for arsit := ars; !arsit.IsEmpty(); arsit = arsit.Tail() { ar := arsit.Head() @@ -518,8 +525,10 @@ func privateAligned(ar usermem.AddrRange) usermem.AddrRange { // the memory it maps, isPMACopyOnWriteLocked will take ownership of the memory // and update the pma to indicate that it does not require copy-on-write. // -// Preconditions: vseg.Range().IsSupersetOf(pseg.Range()). mm.mappingMu must be -// locked. mm.activeMu must be locked for writing. +// Preconditions: +// * vseg.Range().IsSupersetOf(pseg.Range()). +// * mm.mappingMu must be locked. +// * mm.activeMu must be locked for writing. func (mm *MemoryManager) isPMACopyOnWriteLocked(vseg vmaIterator, pseg pmaIterator) bool { pma := pseg.ValuePtr() if !pma.needCOW { @@ -568,8 +577,10 @@ func (mm *MemoryManager) Invalidate(ar usermem.AddrRange, opts memmap.Invalidate // invalidateLocked removes pmas and AddressSpace mappings of those pmas for // addresses in ar. // -// Preconditions: mm.activeMu must be locked for writing. ar.Length() != 0. ar -// must be page-aligned. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivate, invalidateShared bool) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -613,7 +624,9 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat // most I/O. It should only be used in contexts that would use get_user_pages() // in the Linux kernel. // -// Preconditions: ar.Length() != 0. ar must be page-aligned. +// Preconditions: +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) Pin(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) ([]PinnedRange, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -693,9 +706,13 @@ func Unpin(prs []PinnedRange) { // movePMAsLocked moves all pmas in oldAR to newAR. // -// Preconditions: mm.activeMu must be locked for writing. oldAR.Length() != 0. -// oldAR.Length() <= newAR.Length(). !oldAR.Overlaps(newAR). -// mm.pmas.IsEmptyRange(newAR). oldAR and newAR must be page-aligned. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * oldAR.Length() != 0. +// * oldAR.Length() <= newAR.Length(). +// * !oldAR.Overlaps(newAR). +// * mm.pmas.IsEmptyRange(newAR). +// * oldAR and newAR must be page-aligned. func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { if checkInvariants { if !oldAR.WellFormed() || oldAR.Length() <= 0 || !oldAR.IsPageAligned() { @@ -751,9 +768,11 @@ func (mm *MemoryManager) movePMAsLocked(oldAR, newAR usermem.AddrRange) { // - An error that is non-nil if internal mappings exist for only a subset of // ar. // -// Preconditions: mm.activeMu must be locked for writing. -// pseg.Range().Contains(ar.Start). pmas must exist for all addresses in ar. -// ar.Length() != 0. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * pseg.Range().Contains(ar.Start). +// * pmas must exist for all addresses in ar. +// * ar.Length() != 0. // // Postconditions: getPMAInternalMappingsLocked does not invalidate iterators // into mm.pmas. @@ -783,8 +802,9 @@ func (mm *MemoryManager) getPMAInternalMappingsLocked(pseg pmaIterator, ar userm // internal mappings exist. If this is not equal to ars, it returns a non-nil // error explaining why. // -// Preconditions: mm.activeMu must be locked for writing. pmas must exist for -// all addresses in ar. +// Preconditions: +// * mm.activeMu must be locked for writing. +// * pmas must exist for all addresses in ar. // // Postconditions: getVecPMAInternalMappingsLocked does not invalidate iterators // into mm.pmas. @@ -803,9 +823,12 @@ func (mm *MemoryManager) getVecPMAInternalMappingsLocked(ars usermem.AddrRangeSe // internalMappingsLocked returns internal mappings for addresses in ar. // -// Preconditions: mm.activeMu must be locked. Internal mappings must have been -// previously established for all addresses in ar. ar.Length() != 0. -// pseg.Range().Contains(ar.Start). +// Preconditions: +// * mm.activeMu must be locked. +// * Internal mappings must have been previously established for all addresses +// in ar. +// * ar.Length() != 0. +// * pseg.Range().Contains(ar.Start). func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.AddrRange) safemem.BlockSeq { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -839,8 +862,10 @@ func (mm *MemoryManager) internalMappingsLocked(pseg pmaIterator, ar usermem.Add // vecInternalMappingsLocked returns internal mappings for addresses in ars. // -// Preconditions: mm.activeMu must be locked. Internal mappings must have been -// previously established for all addresses in ars. +// Preconditions: +// * mm.activeMu must be locked. +// * Internal mappings must have been previously established for all addresses +// in ars. func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) safemem.BlockSeq { var ims []safemem.Block for ; !ars.IsEmpty(); ars = ars.Tail() { @@ -969,7 +994,9 @@ func (pmaSetFunctions) Split(ar usermem.AddrRange, p pma, split usermem.Addr) (p // findOrSeekPrevUpperBoundPMA returns mm.pmas.UpperBoundSegment(addr), but may do // so by scanning linearly backward from pgap. // -// Preconditions: mm.activeMu must be locked. addr <= pgap.Start(). +// Preconditions: +// * mm.activeMu must be locked. +// * addr <= pgap.Start(). func (mm *MemoryManager) findOrSeekPrevUpperBoundPMA(addr usermem.Addr, pgap pmaGapIterator) pmaIterator { if checkInvariants { if !pgap.Ok() { @@ -1015,7 +1042,9 @@ func (pseg pmaIterator) fileRange() memmap.FileRange { return pseg.fileRangeOf(pseg.Range()) } -// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0. +// Preconditions: +// * pseg.Range().IsSupersetOf(ar). +// * ar.Length != 0. func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange { if checkInvariants { if !pseg.Ok() { diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go index 4cdb52eb6..2dbe5b751 100644 --- a/pkg/sentry/mm/special_mappable.go +++ b/pkg/sentry/mm/special_mappable.go @@ -16,7 +16,6 @@ package mm import ( "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" @@ -31,7 +30,7 @@ import ( // // +stateify savable type SpecialMappable struct { - refs.AtomicRefCount + SpecialMappableRefs mfp pgalloc.MemoryFileProvider fr memmap.FileRange @@ -45,13 +44,13 @@ type SpecialMappable struct { // Preconditions: fr.Length() != 0. func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable { m := SpecialMappable{mfp: mfp, fr: fr, name: name} - m.EnableLeakCheck("mm.SpecialMappable") + m.EnableLeakCheck() return &m } // DecRef implements refs.RefCounter.DecRef. func (m *SpecialMappable) DecRef(ctx context.Context) { - m.AtomicRefCount.DecRefWithDestructor(ctx, func(context.Context) { + m.SpecialMappableRefs.DecRef(func() { m.mfp.MemoryFile().DecRef(m.fr) }) } @@ -137,9 +136,12 @@ func (m *SpecialMappable) Length() uint64 { // NewSharedAnonMappable returns a SpecialMappable that implements the // semantics of mmap(MAP_SHARED|MAP_ANONYMOUS) and mappings of /dev/zero. // -// TODO(jamieliu): The use of SpecialMappable is a lazy code reuse hack. Linux -// uses an ephemeral file created by mm/shmem.c:shmem_zero_setup(); we should -// do the same to get non-zero device and inode IDs. +// TODO(gvisor.dev/issue/1624): Linux uses an ephemeral file created by +// mm/shmem.c:shmem_zero_setup(), and VFS2 does something analogous. VFS1 uses +// a SpecialMappable instead, incorrectly getting device and inode IDs of zero +// and causing memory for shared anonymous mappings to be allocated up-front +// instead of on first touch; this is to avoid exacerbating the fs.MountSource +// leak (b/143656263). Delete this function along with VFS1. func NewSharedAnonMappable(length uint64, mfp pgalloc.MemoryFileProvider) (*SpecialMappable, error) { if length == 0 { return nil, syserror.EINVAL diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index e74d4e1c1..a2555ba1a 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -24,7 +24,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/futex" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" - "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -93,18 +92,6 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme } } else { opts.Offset = 0 - if !opts.Private { - if opts.MappingIdentity != nil { - return 0, syserror.EINVAL - } - m, err := NewSharedAnonMappable(opts.Length, pgalloc.MemoryFileProviderFromContext(ctx)) - if err != nil { - return 0, err - } - defer m.DecRef(ctx) - opts.MappingIdentity = m - opts.Mappable = m - } } if opts.Addr.RoundDown() != opts.Addr { @@ -166,7 +153,9 @@ func (mm *MemoryManager) MMap(ctx context.Context, opts memmap.MMapOpts) (userme // populateVMA obtains pmas for addresses in ar in the given vma, and maps them // into mm.as if it is active. // -// Preconditions: mm.mappingMu must be locked. vseg.Range().IsSupersetOf(ar). +// Preconditions: +// * mm.mappingMu must be locked. +// * vseg.Range().IsSupersetOf(ar). func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) { if !vseg.ValuePtr().effectivePerms.Any() { // Linux doesn't populate inaccessible pages. See @@ -208,8 +197,9 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar u // preferable to populateVMA since it unlocks mm.mappingMu before performing // expensive operations that don't require it to be locked. // -// Preconditions: mm.mappingMu must be locked for writing. -// vseg.Range().IsSupersetOf(ar). +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * vseg.Range().IsSupersetOf(ar). // // Postconditions: mm.mappingMu will be unlocked. func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaIterator, ar usermem.AddrRange, precommit bool) { diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index c4e1989ed..f769d8294 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -27,8 +27,9 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Preconditions: mm.mappingMu must be locked for writing. opts must be valid -// as defined by the checks in MMap. +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * opts must be valid as defined by the checks in MMap. func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOpts) (vmaIterator, usermem.AddrRange, error) { if opts.MaxPerms != opts.MaxPerms.Effective() { panic(fmt.Sprintf("Non-effective MaxPerms %s cannot be enforced", opts.MaxPerms)) @@ -260,8 +261,9 @@ func (mm *MemoryManager) mlockedBytesRangeLocked(ar usermem.AddrRange) uint64 { // // - An error that is non-nil if vmas exist for only a subset of ar. // -// Preconditions: mm.mappingMu must be locked for reading; it may be -// temporarily unlocked. ar.Length() != 0. +// Preconditions: +// * mm.mappingMu must be locked for reading; it may be temporarily unlocked. +// * ar.Length() != 0. func (mm *MemoryManager) getVMAsLocked(ctx context.Context, ar usermem.AddrRange, at usermem.AccessType, ignorePermissions bool) (vmaIterator, vmaGapIterator, error) { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 { @@ -342,8 +344,10 @@ const guardBytes = 256 * usermem.PageSize // unmapLocked unmaps all addresses in ar and returns the resulting gap in // mm.vmas. // -// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0. -// ar must be page-aligned. +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -361,8 +365,10 @@ func (mm *MemoryManager) unmapLocked(ctx context.Context, ar usermem.AddrRange) // gap in mm.vmas. It does not remove pmas or AddressSpace mappings; clients // must do so before calling removeVMAsLocked. // -// Preconditions: mm.mappingMu must be locked for writing. ar.Length() != 0. ar -// must be page-aligned. +// Preconditions: +// * mm.mappingMu must be locked for writing. +// * ar.Length() != 0. +// * ar must be page-aligned. func (mm *MemoryManager) removeVMAsLocked(ctx context.Context, ar usermem.AddrRange) vmaGapIterator { if checkInvariants { if !ar.WellFormed() || ar.Length() <= 0 || !ar.IsPageAligned() { @@ -467,7 +473,9 @@ func (vmaSetFunctions) Split(ar usermem.AddrRange, v vma, split usermem.Addr) (v return v, v2 } -// Preconditions: vseg.ValuePtr().mappable != nil. vseg.Range().Contains(addr). +// Preconditions: +// * vseg.ValuePtr().mappable != nil. +// * vseg.Range().Contains(addr). func (vseg vmaIterator) mappableOffsetAt(addr usermem.Addr) uint64 { if checkInvariants { if !vseg.Ok() { @@ -491,8 +499,10 @@ func (vseg vmaIterator) mappableRange() memmap.MappableRange { return vseg.mappableRangeOf(vseg.Range()) } -// Preconditions: vseg.ValuePtr().mappable != nil. -// vseg.Range().IsSupersetOf(ar). ar.Length() != 0. +// Preconditions: +// * vseg.ValuePtr().mappable != nil. +// * vseg.Range().IsSupersetOf(ar). +// * ar.Length() != 0. func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRange { if checkInvariants { if !vseg.Ok() { @@ -514,8 +524,10 @@ func (vseg vmaIterator) mappableRangeOf(ar usermem.AddrRange) memmap.MappableRan return memmap.MappableRange{vma.off + uint64(ar.Start-vstart), vma.off + uint64(ar.End-vstart)} } -// Preconditions: vseg.ValuePtr().mappable != nil. -// vseg.mappableRange().IsSupersetOf(mr). mr.Length() != 0. +// Preconditions: +// * vseg.ValuePtr().mappable != nil. +// * vseg.mappableRange().IsSupersetOf(mr). +// * mr.Length() != 0. func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { if checkInvariants { if !vseg.Ok() { @@ -540,7 +552,9 @@ func (vseg vmaIterator) addrRangeOf(mr memmap.MappableRange) usermem.AddrRange { // seekNextLowerBound returns mm.vmas.LowerBoundSegment(addr), but does so by // scanning linearly forward from vseg. // -// Preconditions: mm.mappingMu must be locked. addr >= vseg.Start(). +// Preconditions: +// * mm.mappingMu must be locked. +// * addr >= vseg.Start(). func (vseg vmaIterator) seekNextLowerBound(addr usermem.Addr) vmaIterator { if checkInvariants { if !vseg.Ok() { diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 46d3be58c..626d1eaa4 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -507,7 +507,9 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6 // nearest page. If this is shorter than length bytes due to an error returned // by r.ReadToBlocks(), it returns that error. // -// Preconditions: length > 0. length must be page-aligned. +// Preconditions: +// * length > 0. +// * length must be page-aligned. func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) { fr, err := f.Allocate(length, kind) if err != nil { @@ -1167,8 +1169,10 @@ func (f *MemoryFile) startEvictionsLocked() bool { return startedAny } -// Preconditions: info == f.evictable[user]. !info.evicting. f.mu must be -// locked. +// Preconditions: +// * info == f.evictable[user]. +// * !info.evicting. +// * f.mu must be locked. func (f *MemoryFile) startEvictionGoroutineLocked(user EvictableMemoryUser, info *evictableMemoryUserInfo) { info.evicting = true f.evictionWG.Add(1) diff --git a/pkg/sentry/platform/interrupt/interrupt.go b/pkg/sentry/platform/interrupt/interrupt.go index 57be41647..9dfac3eae 100644 --- a/pkg/sentry/platform/interrupt/interrupt.go +++ b/pkg/sentry/platform/interrupt/interrupt.go @@ -54,8 +54,9 @@ type Forwarder struct { // } // defer f.Disable() // -// Preconditions: r must not be nil. f must not already be forwarding -// interrupts to a Receiver. +// Preconditions: +// * r must not be nil. +// * f must not already be forwarding interrupts to a Receiver. func (f *Forwarder) Enable(r Receiver) bool { if r == nil { panic("nil Receiver") diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index e34f46aeb..a182e4f22 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -98,6 +98,10 @@ func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegi } errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { + // Store the physical address in the slot. This is used to + // avoid calls to handleBluepillFault in the future (see + // machine.mapPhysical). + atomic.StoreUintptr(&m.usedSlots[slot], physical) // Successfully added region; we can increment nextSlot and // allow another set to proceed here. atomic.StoreUint32(&m.nextSlot, slot+1) diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index bf357de1a..979be5d89 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 3bf918446..5c4b18899 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -56,6 +56,7 @@ const ( // KVM capability options. const ( + _KVM_CAP_MAX_MEMSLOTS = 0x0a _KVM_CAP_MAX_VCPUS = 0x42 _KVM_CAP_ARM_VM_IPA_SIZE = 0xa5 _KVM_CAP_VCPU_EVENTS = 0x29 @@ -64,6 +65,7 @@ const ( // KVM limits. const ( + _KVM_NR_MEMSLOTS = 0x100 _KVM_NR_VCPUS = 0xff _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 6c54712d1..372a4cbd7 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -43,9 +43,6 @@ type machine struct { // kernel is the set of global structures. kernel ring0.Kernel - // mappingCache is used for mapPhysical. - mappingCache sync.Map - // mu protects vCPUs. mu sync.RWMutex @@ -63,6 +60,12 @@ type machine struct { // maxVCPUs is the maximum number of vCPUs supported by the machine. maxVCPUs int + // maxSlots is the maximum number of memory slots supported by the machine. + maxSlots int + + // usedSlots is the set of used physical addresses (sorted). + usedSlots []uintptr + // nextID is the next vCPU ID. nextID uint32 } @@ -184,6 +187,7 @@ func newMachine(vm int) (*machine, error) { PageTables: pagetables.New(newAllocator()), }) + // Pull the maximum vCPUs. maxVCPUs, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_VCPUS) if errno != 0 { m.maxVCPUs = _KVM_NR_VCPUS @@ -191,11 +195,19 @@ func newMachine(vm int) (*machine, error) { m.maxVCPUs = int(maxVCPUs) } log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) - - // Create the vCPUs map/slices. m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) + // Pull the maximum slots. + maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) + if errno != 0 { + m.maxSlots = _KVM_NR_MEMSLOTS + } else { + m.maxSlots = int(maxSlots) + } + log.Debugf("The maximum number of slots is %d.", m.maxSlots) + m.usedSlots = make([]uintptr, m.maxSlots) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -272,6 +284,20 @@ func newMachine(vm int) (*machine, error) { return m, nil } +// hasSlot returns true iff the given address is mapped. +// +// This must be done via a linear scan. +// +//go:nosplit +func (m *machine) hasSlot(physical uintptr) bool { + for i := 0; i < len(m.usedSlots); i++ { + if p := atomic.LoadUintptr(&m.usedSlots[i]); p == physical { + return true + } + } + return false +} + // mapPhysical checks for the mapping of a physical range, and installs one if // not available. This attempts to be efficient for calls in the hot path. // @@ -286,8 +312,8 @@ func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalReg panic("mapPhysical on unknown physical address") } - if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { - // Not present in the cache; requires setting the slot. + // Is this already mapped? Check the usedSlots. + if !m.hasSlot(physicalStart) { if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index 905712076..537419657 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error { } // tcr_el1 - data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS + data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1 reg.id = _KVM_ARM64_REGS_TCR_EL1 if err := c.setOneRegister(®); err != nil { return err @@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error { c.SetTtbr0Kvm(uintptr(data)) // ttbr1_el1 - data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0) + data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1) reg.id = _KVM_ARM64_REGS_TTBR1_EL1 if err := c.setOneRegister(®); err != nil { diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 9f86f6a7a..607c82156 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go index c8897d34f..4dcdbf8a7 100644 --- a/pkg/sentry/platform/kvm/virtual_map.go +++ b/pkg/sentry/platform/kvm/virtual_map.go @@ -34,7 +34,7 @@ type virtualRegion struct { } // mapsLine matches a single line from /proc/PID/maps. -var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\\s+(.*)") +var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2,3}:[0-9a-f]{2,} [0-9]+\\s+(.*)") // excludeRegion returns true if these regions should be excluded from the // physical map. Virtual regions need to be excluded if get_user_pages will diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go index ba031516a..530e779b0 100644 --- a/pkg/sentry/platform/platform.go +++ b/pkg/sentry/platform/platform.go @@ -245,14 +245,19 @@ type AddressSpace interface { // physical memory) to the mapping. The precommit flag is advisory and // implementations may choose to ignore it. // - // Preconditions: addr and fr must be page-aligned. fr.Length() > 0. - // at.Any() == true. At least one reference must be held on all pages in - // fr, and must continue to be held as long as pages are mapped. + // Preconditions: + // * addr and fr must be page-aligned. + // * fr.Length() > 0. + // * at.Any() == true. + // * At least one reference must be held on all pages in fr, and must + // continue to be held as long as pages are mapped. MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error // Unmap unmaps the given range. // - // Preconditions: addr is page-aligned. length > 0. + // Preconditions: + // * addr is page-aligned. + // * length > 0. Unmap(addr usermem.Addr, length uint64) // Release releases this address space. After releasing, a new AddressSpace diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD index e04165fbf..fc43cc3c0 100644 --- a/pkg/sentry/platform/ptrace/BUILD +++ b/pkg/sentry/platform/ptrace/BUILD @@ -30,7 +30,6 @@ go_library( "//pkg/safecopy", "//pkg/seccomp", "//pkg/sentry/arch", - "//pkg/sentry/hostcpu", "//pkg/sentry/memmap", "//pkg/sentry/platform", "//pkg/sentry/platform/interrupt", diff --git a/pkg/sentry/platform/ptrace/filters.go b/pkg/sentry/platform/ptrace/filters.go index 1e07cfd0d..b0970e356 100644 --- a/pkg/sentry/platform/ptrace/filters.go +++ b/pkg/sentry/platform/ptrace/filters.go @@ -24,10 +24,9 @@ import ( // SyscallFilters returns syscalls made exclusively by the ptrace platform. func (*PTrace) SyscallFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - unix.SYS_GETCPU: {}, - unix.SYS_SCHED_SETAFFINITY: {}, - syscall.SYS_PTRACE: {}, - syscall.SYS_TGKILL: {}, - syscall.SYS_WAIT4: {}, + unix.SYS_GETCPU: {}, + syscall.SYS_PTRACE: {}, + syscall.SYS_TGKILL: {}, + syscall.SYS_WAIT4: {}, } } diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index e1d54d8a2..812ab80ef 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -518,11 +518,6 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { } defer c.interrupt.Disable() - // Ensure that the CPU set is bound appropriately; this makes the - // emulation below several times faster, presumably by avoiding - // interprocessor wakeups and by simplifying the schedule. - t.bind() - // Set registers. if err := t.setRegs(regs); err != nil { panic(fmt.Sprintf("ptrace set regs (%+v) failed: %v", regs, err)) diff --git a/pkg/sentry/platform/ptrace/subprocess_amd64.go b/pkg/sentry/platform/ptrace/subprocess_amd64.go index 84b699f0d..020bbda79 100644 --- a/pkg/sentry/platform/ptrace/subprocess_amd64.go +++ b/pkg/sentry/platform/ptrace/subprocess_amd64.go @@ -201,7 +201,7 @@ func appendArchSeccompRules(rules []seccomp.RuleSet, defaultAction linux.BPFActi seccomp.RuleSet{ Rules: seccomp.SyscallRules{ syscall.SYS_ARCH_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_SET_CPUID), seccomp.AllowValue(0)}, + {seccomp.EqualTo(linux.ARCH_SET_CPUID), seccomp.EqualTo(0)}, }, }, Action: linux.SECCOMP_RET_ALLOW, diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go index 2ce528601..8548853da 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux.go @@ -80,9 +80,9 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro Rules: seccomp.SyscallRules{ syscall.SYS_CLONE: []seccomp.Rule{ // Allow creation of new subprocesses (used by the master). - {seccomp.AllowValue(syscall.CLONE_FILES | syscall.SIGKILL)}, + {seccomp.EqualTo(syscall.CLONE_FILES | syscall.SIGKILL)}, // Allow creation of new threads within a single address space (used by addresss spaces). - {seccomp.AllowValue( + {seccomp.EqualTo( syscall.CLONE_FILES | syscall.CLONE_FS | syscall.CLONE_SIGHAND | @@ -97,14 +97,14 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro // For the stub prctl dance (all). syscall.SYS_PRCTL: []seccomp.Rule{ - {seccomp.AllowValue(syscall.PR_SET_PDEATHSIG), seccomp.AllowValue(syscall.SIGKILL)}, + {seccomp.EqualTo(syscall.PR_SET_PDEATHSIG), seccomp.EqualTo(syscall.SIGKILL)}, }, syscall.SYS_GETPPID: {}, // For the stub to stop itself (all). syscall.SYS_GETPID: {}, syscall.SYS_KILL: []seccomp.Rule{ - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SIGSTOP)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SIGSTOP)}, }, // Injected to support the address space operations. @@ -115,7 +115,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro }) } rules = appendArchSeccompRules(rules, defaultAction) - instrs, err := seccomp.BuildProgram(rules, defaultAction) + instrs, err := seccomp.BuildProgram(rules, defaultAction, defaultAction) if err != nil { return nil, err } diff --git a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go index 245b20722..533e45497 100644 --- a/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_linux_unsafe.go @@ -18,29 +18,12 @@ package ptrace import ( - "sync/atomic" "syscall" "unsafe" - "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/hostcpu" - "gvisor.dev/gvisor/pkg/sync" ) -// maskPool contains reusable CPU masks for setting affinity. Unfortunately, -// runtime.NumCPU doesn't actually record the number of CPUs on the system, it -// just records the number of CPUs available in the scheduler affinity set at -// startup. This may a) change over time and b) gives a number far lower than -// the maximum indexable CPU. To prevent lots of allocation in the hot path, we -// use a pool to store large masks that we can reuse during bind. -var maskPool = sync.Pool{ - New: func() interface{} { - const maxCPUs = 1024 // Not a hard limit; see below. - return make([]uintptr, maxCPUs/64) - }, -} - // unmaskAllSignals unmasks all signals on the current thread. // //go:nosplit @@ -49,47 +32,3 @@ func unmaskAllSignals() syscall.Errno { _, _, errno := syscall.RawSyscall6(syscall.SYS_RT_SIGPROCMASK, linux.SIG_SETMASK, uintptr(unsafe.Pointer(&set)), 0, linux.SignalSetSize, 0, 0) return errno } - -// setCPU sets the CPU affinity. -func (t *thread) setCPU(cpu uint32) error { - mask := maskPool.Get().([]uintptr) - n := int(cpu / 64) - v := uintptr(1 << uintptr(cpu%64)) - if n >= len(mask) { - // See maskPool note above. We've actually exceeded the number - // of available cores. Grow the mask and return it. - mask = make([]uintptr, n+1) - } - mask[n] |= v - if _, _, errno := syscall.RawSyscall( - unix.SYS_SCHED_SETAFFINITY, - uintptr(t.tid), - uintptr(len(mask)*8), - uintptr(unsafe.Pointer(&mask[0]))); errno != 0 { - return errno - } - mask[n] &^= v - maskPool.Put(mask) - return nil -} - -// bind attempts to ensure that the thread is on the same CPU as the current -// thread. This provides no guarantees as it is fundamentally a racy operation: -// CPU sets may change and we may be rescheduled in the middle of this -// operation. As a result, no failures are reported. -// -// Precondition: the current runtime thread should be locked. -func (t *thread) bind() { - currentCPU := hostcpu.GetCPU() - - if oldCPU := atomic.SwapUint32(&t.cpu, currentCPU); oldCPU != currentCPU { - // Set the affinity on the thread and save the CPU for next - // round; we don't expect CPUs to bounce around too frequently. - // - // (It's worth noting that we could move CPUs between this point - // and when the tracee finishes executing. But that would be - // roughly the status quo anyways -- we're just maximizing our - // chances of colocation, not guaranteeing it.) - t.setCPU(currentCPU) - } -} diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index 0bee995e4..7ee20d89a 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 0e2ab716c..508236e46 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -77,6 +77,9 @@ type CPUArchState struct { // lazyVFP is the value of cpacr_el1. lazyVFP uintptr + + // appASID is the asid value of guest application. + appASID uintptr } // ErrorCode returns the last error code. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 9d29b7168..5f63cbd45 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -27,7 +27,9 @@ // ERET returns using the ELR and SPSR for the current exception level. #define ERET() \ - WORD $0xd69f03e0 + WORD $0xd69f03e0; \ + DSB $7; \ + ISB $15; // RSV_REG is a register that holds el1 information temporarily. #define RSV_REG R18_PLATFORM @@ -300,17 +302,23 @@ // SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application. #define SWITCH_TO_APP_PAGETABLE(from) \ - MOVD CPU_TTBR0_APP(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD CPU_APP_ASID(from), R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set) ISB $15; \ - DSB $15; + MOVD CPU_TTBR0_APP(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; // SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable. #define SWITCH_TO_KVM_PAGETABLE(from) \ - MOVD CPU_TTBR0_KVM(from), RSV_REG; \ - WORD $0xd5182012; \ // MSR R18, TTBR0_EL1 + MRS TTBR1_EL1, R0; \ + MOVD $1, R1; \ + BFI $48, R1, $16, R0; \ + MSR R0, TTBR1_EL1; \ ISB $15; \ - DSB $15; + MOVD CPU_TTBR0_KVM(from), RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ @@ -326,23 +334,20 @@ #define KERNEL_ENTRY_FROM_EL0 \ SUB $16, RSP, RSP; \ // step1, save r18, r9 into kernel temporary stack. STP (RSV_REG, RSV_REG_APP), 16*0(RSP); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18, step2, switch user pagetable. - SWITCH_TO_KVM_PAGETABLE(RSV_REG); \ - WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 - MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step3, load app context pointer. - REGISTERS_SAVE(RSV_REG_APP, 0); \ // step4, save app context. + WORD $0xd538d092; \ // MRS TPIDR_EL1, R18 + MOVD CPU_APP_ADDR(RSV_REG), RSV_REG_APP; \ // step2, load app context pointer. + REGISTERS_SAVE(RSV_REG_APP, 0); \ // step3, save app context. MOVD RSV_REG_APP, R20; \ LDP 16*0(RSP), (RSV_REG, RSV_REG_APP); \ ADD $16, RSP, RSP; \ MOVD RSV_REG, PTRACE_R18(R20); \ MOVD RSV_REG_APP, PTRACE_R9(R20); \ - MOVD R20, RSV_REG_APP; \ WORD $0xd5384003; \ // MRS SPSR_EL1, R3 - MOVD R3, PTRACE_PSTATE(RSV_REG_APP); \ + MOVD R3, PTRACE_PSTATE(R20); \ MRS ELR_EL1, R3; \ - MOVD R3, PTRACE_PC(RSV_REG_APP); \ + MOVD R3, PTRACE_PC(R20); \ WORD $0xd5384103; \ // MRS SP_EL0, R3 - MOVD R3, PTRACE_SP(RSV_REG_APP); + MOVD R3, PTRACE_SP(R20); // KERNEL_ENTRY_FROM_EL1 is the entry code of the vcpu from el1 to el1. #define KERNEL_ENTRY_FROM_EL1 \ @@ -357,6 +362,13 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// storeAppASID writes the application's asid value. +TEXT ·storeAppASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + MRS TPIDR_EL1, RSV_REG + MOVD R1, CPU_APP_ASID(RSV_REG) + RET + // Halt halts execution. TEXT ·Halt(SB),NOSPLIT,$0 // Clear bluepill. @@ -414,7 +426,7 @@ TEXT ·Current(SB),NOSPLIT,$0-8 MOVD R8, ret+0(FP) RET -#define STACK_FRAME_SIZE 16 +#define STACK_FRAME_SIZE 32 // kernelExitToEl0 is the entrypoint for application in guest_el0. // Prepare the vcpu environment for container application. @@ -458,15 +470,16 @@ TEXT ·kernelExitToEl0(SB),NOSPLIT,$0 SUB $STACK_FRAME_SIZE, RSP, RSP STP (RSV_REG, RSV_REG_APP), 16*0(RSP) + STP (R0, R1), 16*1(RSP) WORD $0xd538d092 //MRS TPIDR_EL1, R18 SWITCH_TO_APP_PAGETABLE(RSV_REG) + LDP 16*1(RSP), (R0, R1) LDP 16*0(RSP), (RSV_REG, RSV_REG_APP) ADD $STACK_FRAME_SIZE, RSP, RSP - ISB $15 ERET() // kernelExitToEl1 is the entrypoint for sentry in guest_el1. @@ -482,6 +495,9 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP + SWITCH_TO_KVM_PAGETABLE(RSV_REG) + MRS TPIDR_EL1, RSV_REG + REGISTERS_LOAD(RSV_REG, CPU_REGISTERS) MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index d0afa1aaa..14774c5db 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -53,6 +53,11 @@ func IsCanonical(addr uint64) bool { //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { + storeAppASID(uintptr(switchOpts.UserASID)) + if switchOpts.Flush { + FlushTlbAll() + } + regs := switchOpts.Registers regs.Pstate &= ^uint64(PsrFlagsClear) diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 00e52c8af..2f1abcb0f 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -16,6 +16,15 @@ package ring0 +// storeAppASID writes the application's asid value. +func storeAppASID(asid uintptr) + +// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. +func LocalFlushTlbAll() + +// FlushTlbAll flush all tlb. +func FlushTlbAll() + // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 86bfbe46f..8aabf7d0e 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,20 @@ #include "funcdata.h" #include "textflag.h" +TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 + DSB $6 // dsb(nshst) + WORD $0xd508871f // __tlbi(vmalle1) + DSB $7 // dsb(nsh) + ISB $15 + RET + +TEXT ·FlushTlbAll(SB),NOSPLIT,$0 + DSB $10 // dsb(ishst) + WORD $0xd508831f // __tlbi(vmalle1is) + DSB $11 // dsb(ish) + ISB $15 + RET + TEXT ·GetTLS(SB),NOSPLIT,$0-8 MRS TPIDR_EL0, R1 MOVD R1, ret+0(FP) diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index f3de962f0..1d86b4bcf 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -41,6 +41,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define CPU_VECTOR_CODE 0x%02x\n", reflect.ValueOf(&c.vecCode).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_APP_ADDR 0x%02x\n", reflect.ValueOf(&c.appAddr).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "#define CPU_LAZY_VFP 0x%02x\n", reflect.ValueOf(&c.lazyVFP).Pointer()-reflect.ValueOf(c).Pointer()) + fmt.Fprintf(w, "#define CPU_APP_ASID 0x%02x\n", reflect.ValueOf(&c.appASID).Pointer()-reflect.ValueOf(c).Pointer()) fmt.Fprintf(w, "\n// Bits.\n") fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index c0fd3425b..a3f775d15 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/marshal", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", @@ -20,6 +21,5 @@ go_library( "//pkg/syserr", "//pkg/tcpip", "//pkg/usermem", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index 8448ea401..b6ebe29d6 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -21,6 +21,8 @@ go_library( "//pkg/context", "//pkg/fdnotifier", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", @@ -43,8 +45,6 @@ go_library( "//pkg/tcpip/stack", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 242e6bf76..7d3c4a01c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -24,6 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -36,8 +38,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 8a1d52ebf..97bc6027f 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -97,11 +97,6 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal return ioctl(ctx, s.fd, uio, args) } -// Allocate implements vfs.FileDescriptionImpl.Allocate. -func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { - return syserror.ENODEV -} - // PRead implements vfs.FileDescriptionImpl.PRead. func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { return 0, syserror.ESPIPE diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 3d3fabb30..faa61160e 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -123,18 +123,11 @@ func (s *Stack) Configure() error { s.netSNMPFile = f } - s.ipv4Forwarding = false - if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil { - s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" + s.ipv6Forwarding = false + if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil { + s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" } else { - log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false") - } - - s.ipv4Forwarding = false - if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil { - s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" - } else { - log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false") + log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false") } return nil diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 721094bbf..8aea0200f 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -6,6 +6,8 @@ go_library( name = "netfilter", srcs = [ "extensions.go", + "ipv4.go", + "ipv6.go", "netfilter.go", "owner_matcher.go", "targets.go", diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go new file mode 100644 index 000000000..e4c55a100 --- /dev/null +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -0,0 +1,260 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +// emptyIPv4Filter is for comparison with a rule's filters to determine whether +// it is also empty. It is immutable. +var emptyIPv4Filter = stack.IPHeaderFilter{ + Dst: "\x00\x00\x00\x00", + DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", +} + +// convertNetstackToBinary4 converts the iptables as stored in netstack to the +// format expected by the iptables tool. Linux stores each table as a binary +// blob that can only be traversed by parsing a little data, reading some +// offsets, jumping to those offsets, parsing again, etc. +func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + // The table name has to fit in the struct. + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) + } + + table, ok := stk.IPTables().GetTable(tablename.String(), false) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + + // Setup the info struct. + entries, info := getEntries4(table, tablename) + return entries, info, nil +} + +func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo) { + var info linux.IPTGetinfo + var entries linux.KernelIPTGetEntries + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], info.Name[:]) + info.ValidHooks = table.ValidHooks() + + for ruleIdx, rule := range table.Rules { + nflog("convert to binary: current offset: %d", entries.Size) + + setHooksAndUnderflow(&info, table, entries.Size, ruleIdx) + // Each rule corresponds to an entry. + entry := linux.KernelIPTEntry{ + Entry: linux.IPTEntry{ + IP: linux.IPTIP{ + Protocol: uint16(rule.Filter.Protocol), + }, + NextOffset: linux.SizeOfIPTEntry, + TargetOffset: linux.SizeOfIPTEntry, + }, + } + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP + } + if rule.Filter.SrcInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP + } + if rule.Filter.OutputInterfaceInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + } + + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) + nflog("convert to binary: matcher serialized as: %v", serialized) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) + } + + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + + nflog("convert to binary: adding entry: %+v", entry) + + entries.Size += uint32(entry.Entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + info.NumEntries++ + } + + info.Size = entries.Size + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + return entries, info +} + +func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { + nflog("set entries: setting entries in table %q", replace.Name.String()) + + // Convert input into a list of rules and their offsets. + var offset uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + nflog("set entries: processing entry at offset %d", offset) + + // Get the struct ipt_entry. + if len(optVal) < linux.SizeOfIPTEntry { + nflog("optVal has insufficient size for entry %d", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + var entry linux.IPTEntry + buf := optVal[:linux.SizeOfIPTEntry] + binary.Unmarshal(buf, usermem.ByteOrder, &entry) + initialOptValLen := len(optVal) + optVal = optVal[linux.SizeOfIPTEntry:] + + if entry.TargetOffset < linux.SizeOfIPTEntry { + nflog("entry has too-small target offset %d", entry.TargetOffset) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): We should support more IPTIP + // filtering fields. + filter, err := filterFromIPTIP(entry.IP) + if err != nil { + nflog("bad iptip: %v", err) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Matchers and targets can specify + // that they only work for certain protocols, hooks, tables. + // Get matchers. + matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry + if len(optVal) < int(matchersSize) { + nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + matchers, err := parseMatchers(filter, optVal[:matchersSize]) + if err != nil { + nflog("failed to parse matchers: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[matchersSize:] + + // Get the target of the rule. + targetSize := entry.NextOffset - entry.TargetOffset + if len(optVal) < int(targetSize) { + nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + target, err := parseTarget(filter, optVal[:targetSize]) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, stack.Rule{ + Filter: filter, + Target: target, + Matchers: matchers, + }) + offsets[offset] = int(entryIdx) + offset += uint32(entry.NextOffset) + + if initialOptValLen-len(optVal) != int(entry.NextOffset) { + nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return nil, syserr.ErrInvalidArgument + } + } + return offsets, nil +} + +func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { + if containsUnsupportedFields4(iptip) { + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + } + if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) + } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + + return stack.IPHeaderFilter{ + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + // A Protocol value of 0 indicates all protocols match. + CheckProtocol: iptip.Protocol != 0, + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, + }, nil +} + +func containsUnsupportedFields4(iptip linux.IPTIP) bool { + // The following features are supported: + // - Protocol + // - Dst and DstMask + // - Src and SrcMask + // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. + var emptyInterface = [linux.IFNAMSIZ]byte{} + // Disable any supported inverse flags. + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || + iptip.InputInterfaceMask != emptyInterface || + iptip.Flags != 0 || + iptip.InverseFlags&^inverseMask != 0 +} diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go new file mode 100644 index 000000000..3b2c1becd --- /dev/null +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -0,0 +1,265 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +// emptyIPv6Filter is for comparison with a rule's filters to determine whether +// it is also empty. It is immutable. +var emptyIPv6Filter = stack.IPHeaderFilter{ + Dst: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + DstMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", +} + +// convertNetstackToBinary6 converts the ip6tables as stored in netstack to the +// format expected by the iptables tool. Linux stores each table as a binary +// blob that can only be traversed by parsing a little data, reading some +// offsets, jumping to those offsets, parsing again, etc. +func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo, error) { + // The table name has to fit in the struct. + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) + } + + table, ok := stk.IPTables().GetTable(tablename.String(), true) + if !ok { + return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + + // Setup the info struct, which is the same in IPv4 and IPv6. + entries, info := getEntries6(table, tablename) + return entries, info, nil +} + +func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo) { + var info linux.IPTGetinfo + var entries linux.KernelIP6TGetEntries + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], info.Name[:]) + info.ValidHooks = table.ValidHooks() + + for ruleIdx, rule := range table.Rules { + nflog("convert to binary: current offset: %d", entries.Size) + + setHooksAndUnderflow(&info, table, entries.Size, ruleIdx) + // Each rule corresponds to an entry. + entry := linux.KernelIP6TEntry{ + Entry: linux.IP6TEntry{ + IPv6: linux.IP6TIP{ + Protocol: uint16(rule.Filter.Protocol), + }, + NextOffset: linux.SizeOfIP6TEntry, + TargetOffset: linux.SizeOfIP6TEntry, + }, + } + copy(entry.Entry.IPv6.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IPv6.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IPv6.Src[:], rule.Filter.Src) + copy(entry.Entry.IPv6.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IPv6.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IPv6.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_DSTIP + } + if rule.Filter.SrcInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_SRCIP + } + if rule.Filter.OutputInterfaceInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_VIA_OUT + } + if rule.Filter.CheckProtocol { + entry.Entry.IPv6.Flags |= linux.IP6T_F_PROTO + } + + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) + nflog("convert to binary: matcher serialized as: %v", serialized) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) + } + + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + + nflog("convert to binary: adding entry: %+v", entry) + + entries.Size += uint32(entry.Entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + info.NumEntries++ + } + + info.Size = entries.Size + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + return entries, info +} + +func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { + nflog("set entries: setting entries in table %q", replace.Name.String()) + + // Convert input into a list of rules and their offsets. + var offset uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + nflog("set entries: processing entry at offset %d", offset) + + // Get the struct ipt_entry. + if len(optVal) < linux.SizeOfIP6TEntry { + nflog("optVal has insufficient size for entry %d", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + var entry linux.IP6TEntry + buf := optVal[:linux.SizeOfIP6TEntry] + binary.Unmarshal(buf, usermem.ByteOrder, &entry) + initialOptValLen := len(optVal) + optVal = optVal[linux.SizeOfIP6TEntry:] + + if entry.TargetOffset < linux.SizeOfIP6TEntry { + nflog("entry has too-small target offset %d", entry.TargetOffset) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): We should support more IPTIP + // filtering fields. + filter, err := filterFromIP6TIP(entry.IPv6) + if err != nil { + nflog("bad iptip: %v", err) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Matchers and targets can specify + // that they only work for certain protocols, hooks, tables. + // Get matchers. + matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry + if len(optVal) < int(matchersSize) { + nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + matchers, err := parseMatchers(filter, optVal[:matchersSize]) + if err != nil { + nflog("failed to parse matchers: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[matchersSize:] + + // Get the target of the rule. + targetSize := entry.NextOffset - entry.TargetOffset + if len(optVal) < int(targetSize) { + nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + target, err := parseTarget(filter, optVal[:targetSize]) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, stack.Rule{ + Filter: filter, + Target: target, + Matchers: matchers, + }) + offsets[offset] = int(entryIdx) + offset += uint32(entry.NextOffset) + + if initialOptValLen-len(optVal) != int(entry.NextOffset) { + nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return nil, syserr.ErrInvalidArgument + } + } + return offsets, nil +} + +func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { + if containsUnsupportedFields6(iptip) { + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + } + if len(iptip.Dst) != header.IPv6AddressSize || len(iptip.DstMask) != header.IPv6AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + } + if len(iptip.Src) != header.IPv6AddressSize || len(iptip.SrcMask) != header.IPv6AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) + } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + + return stack.IPHeaderFilter{ + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + // In ip6tables a flag controls whether to check the protocol. + CheckProtocol: iptip.Flags&linux.IP6T_F_PROTO != 0, + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IP6T_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0, + }, nil +} + +func containsUnsupportedFields6(iptip linux.IP6TIP) bool { + // The following features are supported: + // - Protocol + // - Dst and DstMask + // - Src and SrcMask + // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. + var emptyInterface = [linux.IFNAMSIZ]byte{} + flagMask := uint8(linux.IP6T_F_PROTO) + // Disable any supported inverse flags. + inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || + iptip.InputInterfaceMask != emptyInterface || + iptip.Flags&^flagMask != 0 || + iptip.InverseFlags&^inverseMask != 0 || + iptip.TOS != 0 +} diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index e91b0624c..871ea80ee 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,7 +17,6 @@ package netfilter import ( - "bytes" "errors" "fmt" @@ -26,8 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/usermem" ) @@ -37,15 +34,6 @@ import ( // developing iptables, but can pollute sentry logs otherwise. const enableLogging = false -// emptyFilter is for comparison with a rule's filters to determine whether it -// is also empty. It is immutable. -var emptyFilter = stack.IPHeaderFilter{ - Dst: "\x00\x00\x00\x00", - DstMask: "\x00\x00\x00\x00", - Src: "\x00\x00\x00\x00", - SrcMask: "\x00\x00\x00\x00", -} - // nflog logs messages related to the writing and reading of iptables. func nflog(format string, args ...interface{}) { if enableLogging && log.IsLogging(log.Debug) { @@ -54,14 +42,19 @@ func nflog(format string, args ...interface{}) { } // GetInfo returns information about iptables. -func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { +func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } - _, info, err := convertNetstackToBinary(stack, info.Name) + var err error + if ipv6 { + _, info, err = convertNetstackToBinary6(stack, info.Name) + } else { + _, info, err = convertNetstackToBinary4(stack, info.Name) + } if err != nil { nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument @@ -71,8 +64,8 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT return info, nil } -// GetEntries returns netstack's iptables rules encoded for the iptables tool. -func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { +// GetEntries4 returns netstack's iptables rules. +func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries if _, err := userEntries.CopyIn(t, outPtr); err != nil { @@ -82,7 +75,7 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, _, err := convertNetstackToBinary(stack, userEntries.Name) + entries, _, err := convertNetstackToBinary4(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -95,112 +88,53 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -// convertNetstackToBinary converts the iptables as stored in netstack to the -// format expected by the iptables tool. Linux stores each table as a binary -// blob that can only be traversed by parsing a bit, reading some offsets, -// jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { - table, ok := stack.IPTables().GetTable(tablename.String()) - if !ok { - return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) +// GetEntries6 returns netstack's ip6tables rules. +func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) { + // Read in the struct and table name. IPv4 and IPv6 utilize structs + // with the same layout. + var userEntries linux.IPTGetEntries + if _, err := userEntries.CopyIn(t, outPtr); err != nil { + nflog("couldn't copy in entries %q", userEntries.Name) + return linux.KernelIP6TGetEntries{}, syserr.FromError(err) } - var entries linux.KernelIPTGetEntries - var info linux.IPTGetinfo - info.ValidHooks = table.ValidHooks() - - // The table name has to fit in the struct. - if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) + // Convert netstack's iptables rules to something that the iptables + // tool can understand. + entries, _, err := convertNetstackToBinary6(stack, userEntries.Name) + if err != nil { + nflog("couldn't read entries: %v", err) + return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument + } + if binary.Size(entries) > uintptr(outLen) { + nflog("insufficient GetEntries output size: %d", uintptr(outLen)) + return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - copy(info.Name[:], tablename[:]) - copy(entries.Name[:], tablename[:]) - - for ruleIdx, rule := range table.Rules { - nflog("convert to binary: current offset: %d", entries.Size) - - // Is this a chain entry point? - for hook, hookRuleIdx := range table.BuiltinChains { - if hookRuleIdx == ruleIdx { - nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) - info.HookEntry[hook] = entries.Size - } - } - // Is this a chain underflow point? - for underflow, underflowRuleIdx := range table.Underflows { - if underflowRuleIdx == ruleIdx { - nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) - info.Underflow[underflow] = entries.Size - } - } - // Each rule corresponds to an entry. - entry := linux.KernelIPTEntry{ - Entry: linux.IPTEntry{ - IP: linux.IPTIP{ - Protocol: uint16(rule.Filter.Protocol), - }, - NextOffset: linux.SizeOfIPTEntry, - TargetOffset: linux.SizeOfIPTEntry, - }, - } - copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) - copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) - copy(entry.Entry.IP.Src[:], rule.Filter.Src) - copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) - copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) - copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) - if rule.Filter.DstInvert { - entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP - } - if rule.Filter.SrcInvert { - entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP - } - if rule.Filter.OutputInterfaceInvert { - entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT - } + return entries, nil +} - for _, matcher := range rule.Matchers { - // Serialize the matcher and add it to the - // entry. - serialized := marshalMatcher(matcher) - nflog("convert to binary: matcher serialized as: %v", serialized) - if len(serialized)%8 != 0 { - panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) - } - entry.Elems = append(entry.Elems, serialized...) - entry.Entry.NextOffset += uint16(len(serialized)) - entry.Entry.TargetOffset += uint16(len(serialized)) +// setHooksAndUnderflow checks whether the rule at ruleIdx is a hook entrypoint +// or underflow, in which case it fills in info.HookEntry and info.Underflows. +func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint32, ruleIdx int) { + // Is this a chain entry point? + for hook, hookRuleIdx := range table.BuiltinChains { + if hookRuleIdx == ruleIdx { + nflog("convert to binary: found hook %d at offset %d", hook, offset) + info.HookEntry[hook] = offset } - - // Serialize and append the target. - serialized := marshalTarget(rule.Target) - if len(serialized)%8 != 0 { - panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) + } + // Is this a chain underflow point? + for underflow, underflowRuleIdx := range table.Underflows { + if underflowRuleIdx == ruleIdx { + nflog("convert to binary: found underflow %d at offset %d", underflow, offset) + info.Underflow[underflow] = offset } - entry.Elems = append(entry.Elems, serialized...) - entry.Entry.NextOffset += uint16(len(serialized)) - - nflog("convert to binary: adding entry: %+v", entry) - - entries.Size += uint32(entry.Entry.NextOffset) - entries.Entrytable = append(entries.Entrytable, entry) - info.NumEntries++ } - - nflog("convert to binary: finished with an marshalled size of %d", info.Size) - info.Size = entries.Size - return entries, info, nil } // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. -func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { - // Get the basic rules data (struct ipt_replace). - if len(optVal) < linux.SizeOfIPTReplace { - nflog("optVal has insufficient size for replace %d", len(optVal)) - return syserr.ErrInvalidArgument - } +func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] @@ -212,85 +146,25 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { case stack.FilterTable: table = stack.EmptyFilterTable() case stack.NATTable: + if ipv6 { + nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)") + return syserr.ErrInvalidArgument + } table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) return syserr.ErrInvalidArgument } - nflog("set entries: setting entries in table %q", replace.Name.String()) - - // Convert input into a list of rules and their offsets. - var offset uint32 - // offsets maps rule byte offsets to their position in table.Rules. - offsets := map[uint32]int{} - for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { - nflog("set entries: processing entry at offset %d", offset) - - // Get the struct ipt_entry. - if len(optVal) < linux.SizeOfIPTEntry { - nflog("optVal has insufficient size for entry %d", len(optVal)) - return syserr.ErrInvalidArgument - } - var entry linux.IPTEntry - buf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(buf, usermem.ByteOrder, &entry) - initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIPTEntry:] - - if entry.TargetOffset < linux.SizeOfIPTEntry { - nflog("entry has too-small target offset %d", entry.TargetOffset) - return syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/170): We should support more IPTIP - // filtering fields. - filter, err := filterFromIPTIP(entry.IP) - if err != nil { - nflog("bad iptip: %v", err) - return syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/170): Matchers and targets can specify - // that they only work for certain protocols, hooks, tables. - // Get matchers. - matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry - if len(optVal) < int(matchersSize) { - nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) - return syserr.ErrInvalidArgument - } - matchers, err := parseMatchers(filter, optVal[:matchersSize]) - if err != nil { - nflog("failed to parse matchers: %v", err) - return syserr.ErrInvalidArgument - } - optVal = optVal[matchersSize:] - - // Get the target of the rule. - targetSize := entry.NextOffset - entry.TargetOffset - if len(optVal) < int(targetSize) { - nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) - return syserr.ErrInvalidArgument - } - target, err := parseTarget(filter, optVal[:targetSize]) - if err != nil { - nflog("failed to parse target: %v", err) - return syserr.ErrInvalidArgument - } - optVal = optVal[targetSize:] - - table.Rules = append(table.Rules, stack.Rule{ - Filter: filter, - Target: target, - Matchers: matchers, - }) - offsets[offset] = int(entryIdx) - offset += uint32(entry.NextOffset) - - if initialOptValLen-len(optVal) != int(entry.NextOffset) { - nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) - return syserr.ErrInvalidArgument - } + var err *syserr.Error + var offsets map[uint32]int + if ipv6 { + offsets, err = modifyEntries6(stk, optVal, &replace, &table) + } else { + offsets, err = modifyEntries4(stk, optVal, &replace, &table) + } + if err != nil { + return err } // Go through the list of supported hooks for this table and, for each @@ -305,7 +179,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { table.BuiltinChains[hk] = ruleIdx } if offset == replace.Underflow[hook] { - if !validUnderflow(table.Rules[ruleIdx]) { + if !validUnderflow(table.Rules[ruleIdx], ipv6) { nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) return syserr.ErrInvalidArgument } @@ -323,7 +197,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { } } - // Add the user chains. + // Check the user chains. for ruleIdx, rule := range table.Rules { if _, ok := rule.Target.(stack.UserChainTarget); !ok { continue @@ -370,7 +244,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { if ruleIdx == stack.HookUnset { continue } - if !isUnconditionalAccept(table.Rules[ruleIdx]) { + if !isUnconditionalAccept(table.Rules[ruleIdx], ipv6) { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument } @@ -382,7 +256,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table)) + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6)) + } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -404,7 +279,6 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, // Check some invariants. if match.MatchSize < linux.SizeOfXTEntryMatch { - return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch) } if len(optVal) < int(match.MatchSize) { @@ -429,64 +303,11 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return matchers, nil } -func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { - if containsUnsupportedFields(iptip) { - return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) - } - if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { - return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) - } - if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { - return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) - } - - n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) - if n == -1 { - n = len(iptip.OutputInterface) - } - ifname := string(iptip.OutputInterface[:n]) - - n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) - if n == -1 { - n = len(iptip.OutputInterfaceMask) - } - ifnameMask := string(iptip.OutputInterfaceMask[:n]) - - return stack.IPHeaderFilter{ - Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), - Dst: tcpip.Address(iptip.Dst[:]), - DstMask: tcpip.Address(iptip.DstMask[:]), - DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, - Src: tcpip.Address(iptip.Src[:]), - SrcMask: tcpip.Address(iptip.SrcMask[:]), - SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, - OutputInterface: ifname, - OutputInterfaceMask: ifnameMask, - OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, - }, nil -} - -func containsUnsupportedFields(iptip linux.IPTIP) bool { - // The following features are supported: - // - Protocol - // - Dst and DstMask - // - Src and SrcMask - // - The inverse destination IP check flag - // - OutputInterface, OutputInterfaceMask and its inverse. - var emptyInterface = [linux.IFNAMSIZ]byte{} - // Disable any supported inverse flags. - inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) - return iptip.InputInterface != emptyInterface || - iptip.InputInterfaceMask != emptyInterface || - iptip.Flags != 0 || - iptip.InverseFlags&^inverseMask != 0 -} - -func validUnderflow(rule stack.Rule) bool { +func validUnderflow(rule stack.Rule, ipv6 bool) bool { if len(rule.Matchers) != 0 { return false } - if rule.Filter != emptyFilter { + if (ipv6 && rule.Filter != emptyIPv6Filter) || (!ipv6 && rule.Filter != emptyIPv4Filter) { return false } switch rule.Target.(type) { @@ -497,8 +318,8 @@ func validUnderflow(rule stack.Rule) bool { } } -func isUnconditionalAccept(rule stack.Rule) bool { - if !validUnderflow(rule) { +func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { + if !validUnderflow(rule, ipv6) { return false } _, ok := rule.Target.(stack.AcceptTarget) diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 8ebdaff18..87e41abd8 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -218,8 +218,8 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) } - if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p) } var redirectTarget linux.XTRedirectTarget @@ -232,7 +232,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro // RangeSize should be 1. if nfRange.RangeSize != 1 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize) } // TODO(gvisor.dev/issue/170): Check if the flags are valid. @@ -240,7 +240,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro // For now, redirect target only supports destination port change. // Port range and IP range are not supported yet. if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags) } target.RangeProtoSpecified = true @@ -249,7 +249,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro // TODO(gvisor.dev/issue/170): Port range is not supported yet. if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) } // Convert port from big endian to little endian. diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 0bfd6c1f4..844acfede 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -97,17 +97,33 @@ func (*TCPMatcher) Name() string { // Match implements Matcher.Match. func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the stack.Check codepath as matchers are added. + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return false, false - } + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.TCPProtocolNumber { + return false, false } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 7ed05461d..63201201c 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -94,19 +94,33 @@ func (*UDPMatcher) Name() string { // Match implements Matcher.Match. func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { - netHeader := header.IPv4(pkt.NetworkHeader().View()) - // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved // into the stack.Check codepath as matchers are added. - if netHeader.TransportProtocol() != header.UDPProtocolNumber { - return false, false - } + switch pkt.NetworkProtocolNumber { + case header.IPv4ProtocolNumber: + netHeader := header.IPv4(pkt.NetworkHeader().View()) + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + return false, false + } - // We dont't match fragments. - if frag := netHeader.FragmentOffset(); frag != 0 { - if frag == 1 { - return false, true + // We don't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false } + + case header.IPv6ProtocolNumber: + // As in Linux, we do not perform an IPv6 fragment check. See + // xt_action_param.fragoff in + // include/linux/netfilter/x_tables.h. + if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.UDPProtocolNumber { + return false, false + } + + default: + // We don't know the network protocol. return false, false } diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 0546801bf..1f926aa91 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -16,6 +16,8 @@ go_library( "//pkg/abi/linux", "//pkg/binary", "//pkg/context", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", @@ -36,8 +38,6 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 68a9b9a96..5ddcd4be5 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -21,6 +21,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -38,8 +40,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 1fb777a6c..fae3b6783 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -22,6 +22,8 @@ go_library( "//pkg/binary", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/safemem", "//pkg/sentry/arch", @@ -51,8 +53,6 @@ go_library( "//pkg/tcpip/transport/udp", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e4846bc0b..6fede181a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -40,6 +40,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -62,8 +64,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { @@ -158,6 +158,9 @@ var Metrics = tcpip.Stats{ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), + IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Total number of IP packets dropped in the Prerouting chain."), + IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."), + IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), @@ -236,7 +239,7 @@ type commonEndpoint interface { // SetSockOpt implements tcpip.Endpoint.SetSockOpt and // transport.Endpoint.SetSockOpt. - SetSockOpt(interface{}) *tcpip.Error + SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and // transport.Endpoint.SetSockOptBool. @@ -248,7 +251,7 @@ type commonEndpoint interface { // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. - GetSockOpt(interface{}) *tcpip.Error + GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and // transport.Endpoint.GetSockOpt. @@ -257,6 +260,9 @@ type commonEndpoint interface { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + + // LastError implements tcpip.Endpoint.LastError. + LastError() *tcpip.Error } // LINT.IfChange @@ -479,8 +485,35 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } // Release implements fs.FileOperations.Release. -func (s *socketOpsCommon) Release(context.Context) { +func (s *socketOpsCommon) Release(ctx context.Context) { + e, ch := waiter.NewChannelEntry(nil) + s.EventRegister(&e, waiter.EventHUp|waiter.EventErr) + defer s.EventUnregister(&e) + s.Endpoint.Close() + + // SO_LINGER option is valid only for TCP. For other socket types + // return after endpoint close. + if family, skType, _ := s.Type(); skType != linux.SOCK_STREAM || (family != linux.AF_INET && family != linux.AF_INET6) { + return + } + + var v tcpip.LingerOption + if err := s.Endpoint.GetSockOpt(&v); err != nil { + return + } + + // The case for zero timeout is handled in tcp endpoint close function. + // Close is blocked until either: + // 1. The endpoint state is not in any of the states: FIN-WAIT1, + // CLOSING and LAST_ACK. + // 2. Timeout is reached. + if v.Enabled && v.Timeout != 0 { + t := kernel.TaskFromContext(ctx) + start := t.Kernel().MonotonicClock().Now() + deadline := start.Add(v.Timeout) + t.BlockWithDeadline(ch, true, deadline) + } } // Read implements fs.FileOperations.Read. @@ -803,7 +836,20 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Issue the bind request to the endpoint. - return syserr.TranslateNetstackError(s.Endpoint.Bind(addr)) + err := s.Endpoint.Bind(addr) + if err == tcpip.ErrNoPortAvailable { + // Bind always returns EADDRINUSE irrespective of if the specified port was + // already bound or if an ephemeral port was requested but none were + // available. + // + // tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because + // UDP connect returns EAGAIN on ephemeral port exhaustion. + // + // TCP connect returns EADDRNOTAVAIL on ephemeral port exhaustion. + err = tcpip.ErrPortInUse + } + + return syserr.TranslateNetstackError(err) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -814,7 +860,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -823,7 +869,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Try to accept the connection again; if it fails, then wait until we // get a notification. for { - if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock { + if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock { return ep, wq, syserr.TranslateNetstackError(err) } @@ -836,15 +882,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite // Accept implements the linux syscall accept(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -864,13 +913,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -943,47 +987,12 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us return &val, nil } - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_GET_INFO: - if outLen < linux.SizeOfIPTGetinfo { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) - if err != nil { - return nil, err - } - return &info, nil - - case linux.IPT_SO_GET_ENTRIES: - if outLen < linux.SizeOfIPTGetEntries { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) - if err != nil { - return nil, err - } - return &entries, nil - - } - } - - return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outPtr, outLen) } // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -992,10 +1001,10 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptTCP(t, ep, name, outLen) case linux.SOL_IPV6: - return getSockOptIPv6(t, ep, name, outLen) + return getSockOptIPv6(t, s, ep, name, outPtr, outLen) case linux.SOL_IP: - return getSockOptIP(t, ep, name, outLen, family) + return getSockOptIP(t, s, ep, name, outPtr, outLen, family) case linux.SOL_UDP, linux.SOL_ICMPV6, @@ -1025,7 +1034,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.GetSockOpt(tcpip.ErrorOption{}) + err := ep.LastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1176,7 +1185,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - linger := linux.Linger{} + var v tcpip.LingerOption + var linger linux.Linger + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + if v.Enabled { + linger.OnOff = 1 + } + linger.Linger = int32(v.Timeout.Seconds()) return &linger, nil case linux.SO_SNDTIMEO: @@ -1390,8 +1408,12 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + var lingerTimeout primitive.Int32 + if v >= 0 { + lingerTimeout = primitive.Int32(time.Duration(v) / time.Second) + } else { + lingerTimeout = -1 + } return &lingerTimeout, nil case linux.TCP_DEFER_ACCEPT: @@ -1437,7 +1459,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1490,10 +1512,50 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha vP := primitive.Int32(boolToInt32(v)) return &vP, nil - case linux.SO_ORIGINAL_DST: + case linux.IP6T_ORIGINAL_DST: // TODO(gvisor.dev/issue/170): ip6tables. return nil, syserr.ErrInvalidArgument + case linux.IP6T_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true) + if err != nil { + return nil, err + } + return &info, nil + + case linux.IP6T_SO_GET_ENTRIES: + // IPTGetEntries is reused for IPv6. + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen) + if err != nil { + return nil, err + } + return &entries, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1501,7 +1563,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1617,6 +1679,46 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil + case linux.IPT_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false) + if err != nil { + return nil, err + } + return &info, nil + + case linux.IPT_SO_GET_ENTRIES: + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen) + if err != nil { + return nil, err + } + return &entries, nil + default: emitUnimplementedEventIP(t, name) } @@ -1650,26 +1752,6 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa return nil } - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_SET_REPLACE: - if len(optVal) < linux.SizeOfIPTReplace { - return syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return syserr.ErrNoDevice - } - // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal) - - case linux.IPT_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. - return nil - } - } - return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } @@ -1684,21 +1766,26 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptTCP(t, ep, name, optVal) case linux.SOL_IPV6: - return setSockOptIPv6(t, ep, name, optVal) + return setSockOptIPv6(t, s, ep, name, optVal) case linux.SOL_IP: - return setSockOptIP(t, ep, name, optVal) + return setSockOptIP(t, s, ep, name, optVal) + + case linux.SOL_PACKET: + // gVisor doesn't support any SOL_PACKET options just return not + // supported. Returning nil here will result in tcpdump thinking AF_PACKET + // features are supported and proceed to use them and break. + t.Kernel().EmitUnimplementedEvent(t) + return syserr.ErrProtocolNotAvailable case linux.SOL_UDP, linux.SOL_ICMPV6, - linux.SOL_RAW, - linux.SOL_PACKET: + linux.SOL_RAW: t.Kernel().EmitUnimplementedEvent(t) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptSocket implements SetSockOpt when level is SOL_SOCKET. @@ -1743,7 +1830,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } name := string(optVal[:n]) if name == "" { - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + v := tcpip.BindToDeviceOption(0) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) } s := t.NetworkContext() if s == nil { @@ -1751,7 +1839,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } for nicID, nic := range s.Interfaces() { if nic.Name == name { - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + v := tcpip.BindToDeviceOption(nicID) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) } } return syserr.ErrUnknownDevice @@ -1817,7 +1906,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + opt := tcpip.OutOfBandInlineOption(v) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1839,19 +1929,21 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return nil + return syserr.TranslateNetstackError( + ep.SetSockOpt(&tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger)})) case linux.SO_DETACH_FILTER: // optval is ignored. var v tcpip.SocketDetachFilterOption - return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) default: socket.SetSockOptEmitUnimplementedEvent(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. @@ -1898,7 +1990,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 1 || v > linux.MAX_TCP_KEEPIDLE { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v)))) + opt := tcpip.KeepaliveIdleOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_KEEPINTVL: if len(optVal) < sizeOfInt32 { @@ -1909,7 +2002,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 1 || v > linux.MAX_TCP_KEEPINTVL { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + opt := tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_KEEPCNT: if len(optVal) < sizeOfInt32 { @@ -1931,11 +2025,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 0 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + opt := tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) - if err := ep.SetSockOpt(v); err != nil { + if err := ep.SetSockOpt(&v); err != nil { return syserr.TranslateNetstackError(err) } return nil @@ -1945,8 +2040,9 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * return syserr.ErrInvalidArgument } - v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + v := int32(usermem.ByteOrder.Uint32(optVal)) + opt := tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_DEFER_ACCEPT: if len(optVal) < sizeOfInt32 { @@ -1956,7 +2052,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 0 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + opt := tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_SYNCNT: if len(optVal) < sizeOfInt32 { @@ -1981,12 +2078,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * emitUnimplementedEventTCP(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. -func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { @@ -2035,12 +2131,32 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + case linux.IP6T_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIP6TReplace { + return syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true) + + case linux.IP6T_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + default: emitUnimplementedEventIPv6(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } var ( @@ -2095,7 +2211,7 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { } // setSockOptIP implements SetSockOpt when level is SOL_IP. -func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2118,7 +2234,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.AddMembershipOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{ NIC: tcpip.NICID(req.InterfaceIndex), // TODO(igudger): Change AddMembership to use the standard // any address representation. @@ -2132,7 +2248,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.RemoveMembershipOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{ NIC: tcpip.NICID(req.InterfaceIndex), // TODO(igudger): Change DropMembership to use the standard // any address representation. @@ -2146,7 +2262,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastInterfaceOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), })) @@ -2215,6 +2331,27 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + case linux.IPT_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + + // Only valid for raw IPv4 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW { + return syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false) + + case linux.IPT_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -2249,8 +2386,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s t.Kernel().EmitUnimplementedEvent(t) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // emitUnimplementedEventTCP emits unimplemented event if name is valid. This diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 3335e7430..c0212ad76 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -18,21 +18,19 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" - "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // SocketVFS2 encapsulates all the state needed to represent a network stack @@ -58,6 +56,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu mnt := t.Kernel().SocketMount() d := sockfs.NewDentry(t.Credentials(), mnt) + defer d.DecRef(t) s := &SocketVFS2{ socketOpsCommon: socketOpsCommon{ @@ -152,14 +151,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs // tcpip.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. - ep, wq, terr := s.Endpoint.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, wq, terr := s.Endpoint.Accept(peerAddr) if terr != nil { if terr != tcpip.ErrWouldBlock || !blocking { return 0, nil, 0, syserr.TranslateNetstackError(terr) } var err *syserr.Error - ep, wq, err = s.blockingAccept(t) + ep, wq, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -177,13 +180,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { + if peerAddr != nil { // Get address of the peer and write it to peer slice. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + addr, addrLen = ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ @@ -233,42 +232,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem. return &val, nil } - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_GET_INFO: - if outLen < linux.SizeOfIPTGetinfo { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) - if err != nil { - return nil, err - } - return &info, nil - - case linux.IPT_SO_GET_ENTRIES: - if outLen < linux.SizeOfIPTGetEntries { - return nil, syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return nil, syserr.ErrNoDevice - } - entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) - if err != nil { - return nil, err - } - return &entries, nil - - } - } - - return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outPtr, outLen) } // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by @@ -298,26 +262,6 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by return nil } - if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { - switch name { - case linux.IPT_SO_SET_REPLACE: - if len(optVal) < linux.SizeOfIPTReplace { - return syserr.ErrInvalidArgument - } - - stack := inet.StackFromContext(t) - if stack == nil { - return syserr.ErrNoDevice - } - // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal) - - case linux.IPT_SO_SET_ADD_COUNTERS: - // TODO(gvisor.dev/issue/170): Counter support. - return nil - } - } - return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index f9097d6b2..1028d2a6e 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - var rs tcp.ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs) return inet.TCPBufferSize{ Min: rs.Min, @@ -166,17 +166,17 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - rs := tcp.ReceiveBufferSizeOption{ + rs := tcpip.TCPReceiveBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &rs)).ToError() } // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - var ss tcp.SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss) return inet.TCPBufferSize{ Min: ss.Min, @@ -187,29 +187,30 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - ss := tcp.SendBufferSizeOption{ + ss := tcpip.TCPSendBufferSizeRangeOption{ Min: size.Min, Default: size.Default, Max: size.Max, } - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError() + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &ss)).ToError() } // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. func (s *Stack) TCPSACKEnabled() (bool, error) { - var sack tcp.SACKEnabled + var sack tcpip.TCPSACKEnabled err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack) return bool(sack), syserr.TranslateNetstackError(err).ToError() } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() + opt := tcpip.TCPSACKEnabled(enabled) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() } // TCPRecovery implements inet.Stack.TCPRecovery. func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { - var recovery tcp.Recovery + var recovery tcpip.TCPRecovery if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } @@ -218,7 +219,8 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { // SetTCPRecovery implements inet.Stack.SetTCPRecovery. func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { - return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError() + opt := tcpip.TCPRecovery(recovery) + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError() } // Statistics implements inet.Stack.Statistics. @@ -417,8 +419,7 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { case ipv4.ProtocolNumber, ipv6.ProtocolNumber: return s.Stack.Forwarding(protocol) default: - log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol) - return false + panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol)) } } @@ -428,8 +429,7 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) case ipv4.ProtocolNumber, ipv6.ProtocolNumber: s.Stack.SetForwarding(protocol, enable) default: - log.Warningf("SetForwarding(%v) failed: unsupported protocol", protocol) - return syserr.ErrProtocolNotSupported.ToError() + panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol)) } return nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 04b259d27..fd31479e5 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -35,7 +36,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cb953e4dc..a89583dad 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -29,6 +29,7 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/marshal", "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", @@ -49,6 +50,5 @@ go_library( "//pkg/tcpip", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index c708b6030..26c3a51b9 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -15,6 +15,17 @@ go_template_instance( }, ) +go_template_instance( + name = "queue_refs", + out = "queue_refs.go", + package = "transport", + prefix = "queue", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "queue", + }, +) + go_library( name = "transport", srcs = [ @@ -22,6 +33,7 @@ go_library( "connectioned_state.go", "connectionless.go", "queue.go", + "queue_refs.go", "transport_message_list.go", "unix.go", ], diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index c67b602f0..aa4f3c04d 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -142,9 +142,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E } q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} - q1.EnableLeakCheck("transport.queue") + q1.EnableLeakCheck() q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - q2.EnableLeakCheck("transport.queue") + q2.EnableLeakCheck() if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} @@ -300,14 +300,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn } readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} - readQueue.EnableLeakCheck("transport.queue") + readQueue.EnableLeakCheck() ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - writeQueue.EnableLeakCheck("transport.queue") + writeQueue.EnableLeakCheck() if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { @@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error { } // Accept accepts a new connection. -func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { +func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) { e.Lock() defer e.Unlock() @@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) { select { case ne := <-e.acceptedChan: + if peerAddr != nil { + ne.Lock() + c := ne.connected + ne.Unlock() + if c != nil { + addr, err := c.GetLocalAddress() + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + *peerAddr = addr + } + } return ne, nil default: diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 70ee8f9b8..f8aacca13 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -42,7 +42,7 @@ var ( func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} - q.EnableLeakCheck("transport.queue") + q.EnableLeakCheck() ep.receiver = &queueReceiver{readQueue: &q} return ep } @@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi } // Listen starts listening on the connection. -func (e *connectionlessEndpoint) Listen(int) *syserr.Error { +func (*connectionlessEndpoint) Listen(int) *syserr.Error { return syserr.ErrNotSupported } // Accept accepts a new connection. -func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) { +func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) { return nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index ef6043e19..342def28f 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -16,7 +16,6 @@ package transport import ( "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" @@ -28,7 +27,7 @@ import ( // // +stateify savable type queue struct { - refs.AtomicRefCount + queueRefs ReaderQueue *waiter.Queue WriterQueue *waiter.Queue @@ -68,11 +67,13 @@ func (q *queue) Reset(ctx context.Context) { q.mu.Unlock() } -// DecRef implements RefCounter.DecRef with destructor q.Reset. +// DecRef implements RefCounter.DecRef. func (q *queue) DecRef(ctx context.Context) { - q.DecRefWithDestructor(ctx, q.Reset) - // We don't need to notify after resetting because no one cares about - // this queue after all references have been dropped. + q.queueRefs.DecRef(func() { + // We don't need to notify after resetting because no one cares about + // this queue after all references have been dropped. + q.Reset(ctx) + }) } // IsReadable determines if q is currently readable. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 475d7177e..d6fc03520 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -151,7 +151,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *syserr.Error) + // + // peerAddr if not nil will be populated with the address of the connected + // peer on a successful accept. + Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. @@ -172,9 +175,8 @@ type Endpoint interface { // connected. GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) - // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option - // types. - SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOpt sets a socket option. + SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error // SetSockOptBool sets a socket option for simple cases when a value has // the int type. @@ -184,9 +186,8 @@ type Endpoint interface { // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error - // GetSockOpt gets a socket option. opt should be a pointer to one of the - // tcpip.*Option types. - GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOpt gets a socket option. + GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error // GetSockOptBool gets a socket option for simple cases when a return // value has the int type. @@ -199,6 +200,9 @@ type Endpoint interface { // State returns the current state of the socket, as represented by Linux in // procfs. State() uint32 + + // LastError implements tcpip.Endpoint.LastError. + LastError() *tcpip.Error } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket @@ -742,6 +746,9 @@ type baseEndpoint struct { // path is not empty if the endpoint has been bound, // or may be used if the endpoint is connected. path string + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // EventRegister implements waiter.Waitable.EventRegister. @@ -837,8 +844,14 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return n, err } -// SetSockOpt sets a socket option. Currently not supported. -func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { +// SetSockOpt sets a socket option. +func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.LingerOption: + e.Lock() + e.linger = *v + e.Unlock() + } return nil } @@ -940,9 +953,12 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: +func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.Lock() + *o = e.linger + e.Unlock() return nil default: @@ -951,6 +967,11 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } +// LastError implements Endpoint.LastError. +func (*baseEndpoint) LastError() *tcpip.Error { + return nil +} + // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index b7e8e4325..917055cea 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -39,7 +40,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -194,7 +194,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { - return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -205,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -214,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -227,15 +227,18 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -252,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index d066ef8ab..3688f22d2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" @@ -32,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" - "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketVFS2 implements socket.SocketVFS2 (and by extension, @@ -91,12 +91,12 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { - return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen) } // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { +func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.socketOpsCommon.EventRegister(&e, waiter.EventIn) @@ -105,7 +105,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Try to accept the connection; if it fails, then wait until we get a // notification. for { - if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock { return ep, err } @@ -118,15 +118,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - // Issue the accept request to get the new endpoint. - ep, err := s.ep.Accept() + var peerAddr *tcpip.FullAddress + if peerRequested { + peerAddr = &tcpip.FullAddress{} + } + ep, err := s.ep.Accept(peerAddr) if err != nil { if err != syserr.ErrWouldBlock || !blocking { return 0, nil, 0, err } var err *syserr.Error - ep, err = s.blockingAccept(t) + ep, err = s.blockingAccept(t, peerAddr) if err != nil { return 0, nil, 0, err } @@ -144,13 +147,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 - if peerRequested { - // Get address of the peer. - var err *syserr.Error - addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) - if err != nil { - return 0, nil, 0, err - } + if peerAddr != nil { + addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 88d5db9fc..a920180d3 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -28,6 +28,7 @@ go_library( "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", + "//pkg/marshal/primitive", "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", diff --git a/pkg/sentry/strace/epoll.go b/pkg/sentry/strace/epoll.go index 5d51a7792..ae3b998c8 100644 --- a/pkg/sentry/strace/epoll.go +++ b/pkg/sentry/strace/epoll.go @@ -26,7 +26,7 @@ import ( func epollEvent(t *kernel.Task, eventAddr usermem.Addr) string { var e linux.EpollEvent - if _, err := t.CopyIn(eventAddr, &e); err != nil { + if _, err := e.CopyIn(t, eventAddr); err != nil { return fmt.Sprintf("%#x {error reading event: %v}", eventAddr, err) } var sb strings.Builder @@ -41,7 +41,7 @@ func epollEvents(t *kernel.Task, eventsAddr usermem.Addr, numEvents, maxBytes ui addr := eventsAddr for i := uint64(0); i < numEvents; i++ { var e linux.EpollEvent - if _, err := t.CopyIn(addr, &e); err != nil { + if _, err := e.CopyIn(t, addr); err != nil { fmt.Fprintf(&sb, "{error reading event at %#x: %v}", addr, err) continue } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index b51c4c941..cc5f70cd4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" @@ -166,7 +167,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) } buf := make([]byte, length) - if _, err := t.CopyIn(addr, &buf); err != nil { + if _, err := t.CopyInBytes(addr, buf); err != nil { return fmt.Sprintf("%#x (error decoding control: %v)", addr, err) } @@ -302,7 +303,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64) func msghdr(t *kernel.Task, addr usermem.Addr, printContent bool, maxBytes uint64) string { var msg slinux.MessageHeader64 - if err := slinux.CopyInMessageHeader64(t, addr, &msg); err != nil { + if _, err := msg.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding msghdr: %v)", addr, err) } s := fmt.Sprintf( @@ -380,9 +381,9 @@ func postSockAddr(t *kernel.Task, addr usermem.Addr, lengthPtr usermem.Addr) str func copySockLen(t *kernel.Task, addr usermem.Addr) (uint32, error) { // socklen_t is 32-bits. - var l uint32 - _, err := t.CopyIn(addr, &l) - return l, err + var l primitive.Uint32 + _, err := l.CopyIn(t, addr) + return uint32(l), err } func sockLenPointer(t *kernel.Task, addr usermem.Addr) string { @@ -436,22 +437,22 @@ func getSockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, o func sockOptVal(t *kernel.Task, level, optname uint64, optVal usermem.Addr, optLen uint64, maximumBlobSize uint) string { switch optLen { case 1: - var v uint8 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint8 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } return fmt.Sprintf("%#x {value=%v}", optVal, v) case 2: - var v uint16 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint16 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } return fmt.Sprintf("%#x {value=%v}", optVal, v) case 4: - var v uint32 - _, err := t.CopyIn(optVal, &v) + var v primitive.Uint32 + _, err := v.CopyIn(t, optVal) if err != nil { return fmt.Sprintf("%#x {error reading optval: %v}", optVal, err) } @@ -632,6 +633,8 @@ var sockOptNames = map[uint64]abi.ValueSet{ linux.IPV6_UNICAST_IF: "IPV6_UNICAST_IF", linux.MCAST_MSFILTER: "MCAST_MSFILTER", linux.IPV6_ADDRFORM: "IPV6_ADDRFORM", + linux.IP6T_SO_GET_INFO: "IP6T_SO_GET_INFO", + linux.IP6T_SO_GET_ENTRIES: "IP6T_SO_GET_ENTRIES", }, linux.SOL_NETLINK: { linux.NETLINK_BROADCAST_ERROR: "NETLINK_BROADCAST_ERROR", diff --git a/pkg/sentry/strace/strace.go b/pkg/sentry/strace/strace.go index 87b239730..52281ccc2 100644 --- a/pkg/sentry/strace/strace.go +++ b/pkg/sentry/strace/strace.go @@ -21,13 +21,13 @@ import ( "fmt" "strconv" "strings" - "syscall" "time" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/eventchannel" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/seccomp" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -91,7 +91,7 @@ func iovecs(t *kernel.Task, addr usermem.Addr, iovcnt int, printContent bool, ma } b := make([]byte, size) - amt, err := t.CopyIn(ar.Start, b) + amt, err := t.CopyInBytes(ar.Start, b) if err != nil { iovs[i] = fmt.Sprintf("{base=%#x, len=%d, %q..., error decoding string: %v}", ar.Start, ar.Length(), b[:amt], err) continue @@ -118,7 +118,7 @@ func dump(t *kernel.Task, addr usermem.Addr, size uint, maximumBlobSize uint) st } b := make([]byte, size) - amt, err := t.CopyIn(addr, b) + amt, err := t.CopyInBytes(addr, b) if err != nil { return fmt.Sprintf("%#x (error decoding string: %s)", addr, err) } @@ -199,7 +199,7 @@ func fdVFS2(t *kernel.Task, fd int32) string { func fdpair(t *kernel.Task, addr usermem.Addr) string { var fds [2]int32 - _, err := t.CopyIn(addr, &fds) + _, err := primitive.CopyInt32SliceIn(t, addr, fds[:]) if err != nil { return fmt.Sprintf("%#x (error decoding fds: %s)", addr, err) } @@ -209,7 +209,7 @@ func fdpair(t *kernel.Task, addr usermem.Addr) string { func uname(t *kernel.Task, addr usermem.Addr) string { var u linux.UtsName - if _, err := t.CopyIn(addr, &u); err != nil { + if _, err := u.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding utsname: %s)", addr, err) } @@ -222,7 +222,7 @@ func utimensTimespec(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timespec - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err) } @@ -244,7 +244,7 @@ func timespec(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timespec - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timespec: %s)", addr, err) } return fmt.Sprintf("%#x {sec=%v nsec=%v}", addr, tim.Sec, tim.Nsec) @@ -256,7 +256,7 @@ func timeval(t *kernel.Task, addr usermem.Addr) string { } var tim linux.Timeval - if _, err := t.CopyIn(addr, &tim); err != nil { + if _, err := tim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding timeval: %s)", addr, err) } @@ -268,8 +268,8 @@ func utimbuf(t *kernel.Task, addr usermem.Addr) string { return "null" } - var utim syscall.Utimbuf - if _, err := t.CopyIn(addr, &utim); err != nil { + var utim linux.Utime + if _, err := utim.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding utimbuf: %s)", addr, err) } @@ -282,7 +282,7 @@ func stat(t *kernel.Task, addr usermem.Addr) string { } var stat linux.Stat - if _, err := t.CopyIn(addr, &stat); err != nil { + if _, err := stat.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding stat: %s)", addr, err) } return fmt.Sprintf("%#x {dev=%d, ino=%d, mode=%s, nlink=%d, uid=%d, gid=%d, rdev=%d, size=%d, blksize=%d, blocks=%d, atime=%s, mtime=%s, ctime=%s}", addr, stat.Dev, stat.Ino, linux.FileMode(stat.Mode), stat.Nlink, stat.UID, stat.GID, stat.Rdev, stat.Size, stat.Blksize, stat.Blocks, time.Unix(stat.ATime.Sec, stat.ATime.Nsec), time.Unix(stat.MTime.Sec, stat.MTime.Nsec), time.Unix(stat.CTime.Sec, stat.CTime.Nsec)) @@ -330,7 +330,7 @@ func rusage(t *kernel.Task, addr usermem.Addr) string { } var ru linux.Rusage - if _, err := t.CopyIn(addr, &ru); err != nil { + if _, err := ru.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding rusage: %s)", addr, err) } return fmt.Sprintf("%#x %+v", addr, ru) @@ -342,7 +342,7 @@ func capHeader(t *kernel.Task, addr usermem.Addr) string { } var hdr linux.CapUserHeader - if _, err := t.CopyIn(addr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error decoding header: %s)", addr, err) } @@ -367,7 +367,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { } var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return fmt.Sprintf("%#x (error decoding header: %v)", dataAddr, err) } @@ -376,7 +376,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { switch hdr.Version { case linux.LINUX_CAPABILITY_VERSION_1: var data linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := data.CopyIn(t, dataAddr); err != nil { return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err) } p = uint64(data.Permitted) @@ -384,7 +384,7 @@ func capData(t *kernel.Task, hdrAddr, dataAddr usermem.Addr) string { e = uint64(data.Effective) case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3: var data [2]linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil { return fmt.Sprintf("%#x (error decoding data: %v)", dataAddr, err) } p = uint64(data[0].Permitted) | (uint64(data[1].Permitted) << 32) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index 4a9b04fd0..75752b2e6 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -56,6 +56,7 @@ go_library( "sys_xattr.go", "timespec.go", ], + marshal = True, visibility = ["//:sandbox"], deps = [ "//pkg/abi", @@ -64,6 +65,8 @@ go_library( "//pkg/bpf", "//pkg/context", "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/metric", "//pkg/rand", "//pkg/safemem", @@ -99,7 +102,5 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index 46060f6f5..dab6207c0 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -147,7 +147,7 @@ func handleIOErrorImpl(t *kernel.Task, partialResult bool, err, intr error, op s } switch err.(type) { - case kernel.SyscallRestartErrno: + case syserror.SyscallRestartErrno: // Identical to the EINTR case. return true, nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 80c65164a..5f26697d2 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -138,7 +138,7 @@ var AMD64 = &kernel.SyscallTable{ 83: syscalls.Supported("mkdir", Mkdir), 84: syscalls.Supported("rmdir", Rmdir), 85: syscalls.Supported("creat", Creat), - 86: syscalls.Supported("link", Link), + 86: syscalls.PartiallySupported("link", Link, "Limited support with Gofer. Link count and linked files may get out of sync because gVisor is not aware of external hardlinks.", nil), 87: syscalls.Supported("unlink", Unlink), 88: syscalls.Supported("symlink", Symlink), 89: syscalls.Supported("readlink", Readlink), @@ -305,9 +305,9 @@ var AMD64 = &kernel.SyscallTable{ 250: syscalls.Error("keyctl", syserror.EACCES, "Not available to user.", nil), 251: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) 252: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) - 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "inotify events are only available inside the sandbox.", nil), - 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), - 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 253: syscalls.PartiallySupported("inotify_init", InotifyInit, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 254: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 255: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), 256: syscalls.CapError("migrate_pages", linux.CAP_SYS_NICE, "", nil), 257: syscalls.Supported("openat", Openat), 258: syscalls.Supported("mkdirat", Mkdirat), @@ -317,7 +317,7 @@ var AMD64 = &kernel.SyscallTable{ 262: syscalls.Supported("fstatat", Fstatat), 263: syscalls.Supported("unlinkat", Unlinkat), 264: syscalls.Supported("renameat", Renameat), - 265: syscalls.Supported("linkat", Linkat), + 265: syscalls.PartiallySupported("linkat", Linkat, "See link(2).", nil), 266: syscalls.Supported("symlinkat", Symlinkat), 267: syscalls.Supported("readlinkat", Readlinkat), 268: syscalls.Supported("fchmodat", Fchmodat), @@ -346,7 +346,7 @@ var AMD64 = &kernel.SyscallTable{ 291: syscalls.Supported("epoll_create1", EpollCreate1), 292: syscalls.Supported("dup3", Dup3), 293: syscalls.Supported("pipe2", Pipe2), - 294: syscalls.Supported("inotify_init1", InotifyInit1), + 294: syscalls.PartiallySupported("inotify_init1", InotifyInit1, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), 295: syscalls.Supported("preadv", Preadv), 296: syscalls.Supported("pwritev", Pwritev), 297: syscalls.Supported("rt_tgsigqueueinfo", RtTgsigqueueinfo), @@ -454,9 +454,9 @@ var ARM64 = &kernel.SyscallTable{ 23: syscalls.Supported("dup", Dup), 24: syscalls.Supported("dup3", Dup3), 25: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), - 26: syscalls.Supported("inotify_init1", InotifyInit1), - 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil), - 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil), + 26: syscalls.PartiallySupported("inotify_init1", InotifyInit1, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 27: syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), + 28: syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "Inotify events are only available inside the sandbox. Hard links are treated as different watch targets in gofer fs.", nil), 29: syscalls.PartiallySupported("ioctl", Ioctl, "Only a few ioctls are implemented for backing devices and file systems.", nil), 30: syscalls.CapError("ioprio_set", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) 31: syscalls.CapError("ioprio_get", linux.CAP_SYS_ADMIN, "", nil), // requires cap_sys_nice or cap_sys_admin (depending) diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index e9d64dec5..0bf313a13 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -17,6 +17,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -36,7 +37,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca // // The context pointer _must_ be zero initially. var idIn uint64 - if _, err := t.CopyIn(idAddr, &idIn); err != nil { + if _, err := primitive.CopyUint64In(t, idAddr, &idIn); err != nil { return 0, nil, err } if idIn != 0 { @@ -49,7 +50,7 @@ func IoSetup(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Copy out the new ID. - if _, err := t.CopyOut(idAddr, &id); err != nil { + if _, err := primitive.CopyUint64Out(t, idAddr, id); err != nil { t.MemoryManager().DestroyAIOContext(t, id) return 0, nil, err } @@ -142,7 +143,7 @@ func IoGetevents(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S ev := v.(*linux.IOEvent) // Copy out the result. - if _, err := t.CopyOut(eventsAddr, ev); err != nil { + if _, err := ev.CopyOut(t, eventsAddr); err != nil { if count > 0 { return uintptr(count), nil, nil } @@ -338,21 +339,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } for i := int32(0); i < nrEvents; i++ { - // Copy in the address. - cbAddrNative := t.Arch().Native(0) - if _, err := t.CopyIn(addr, cbAddrNative); err != nil { - if i > 0 { - // Some successful. - return uintptr(i), nil, nil + // Copy in the callback address. + var cbAddr usermem.Addr + switch t.Arch().Width() { + case 8: + var cbAddrP primitive.Uint64 + if _, err := cbAddrP.CopyIn(t, addr); err != nil { + if i > 0 { + // Some successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err } - // Nothing done. - return 0, nil, err + cbAddr = usermem.Addr(cbAddrP) + default: + return 0, nil, syserror.ENOSYS } // Copy in this callback. var cb linux.IOCallback - cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) - if _, err := t.CopyIn(cbAddr, &cb); err != nil { + if _, err := cb.CopyIn(t, cbAddr); err != nil { if i > 0 { // Some have been successful. diff --git a/pkg/sentry/syscalls/linux/sys_capability.go b/pkg/sentry/syscalls/linux/sys_capability.go index adf5ea5f2..d3b85e11b 100644 --- a/pkg/sentry/syscalls/linux/sys_capability.go +++ b/pkg/sentry/syscalls/linux/sys_capability.go @@ -45,7 +45,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal dataAddr := args[1].Pointer() var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return 0, nil, err } // hdr.Pid doesn't need to be valid if this capget() is a "version probe" @@ -65,7 +65,7 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal Permitted: uint32(p), Inheritable: uint32(i), } - _, err = t.CopyOut(dataAddr, &data) + _, err = data.CopyOut(t, dataAddr) return 0, nil, err case linux.LINUX_CAPABILITY_VERSION_2, linux.LINUX_CAPABILITY_VERSION_3: @@ -88,12 +88,12 @@ func Capget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal Inheritable: uint32(i >> 32), }, } - _, err = t.CopyOut(dataAddr, &data) + _, err = linux.CopyCapUserDataSliceOut(t, dataAddr, data[:]) return 0, nil, err default: hdr.Version = linux.HighestCapabilityVersion - if _, err := t.CopyOut(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyOut(t, hdrAddr); err != nil { return 0, nil, err } if dataAddr != 0 { @@ -109,7 +109,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal dataAddr := args[1].Pointer() var hdr linux.CapUserHeader - if _, err := t.CopyIn(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyIn(t, hdrAddr); err != nil { return 0, nil, err } switch hdr.Version { @@ -118,7 +118,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EPERM } var data linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := data.CopyIn(t, dataAddr); err != nil { return 0, nil, err } p := auth.CapabilitySet(data.Permitted) & auth.AllCapabilities @@ -131,7 +131,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.EPERM } var data [2]linux.CapUserData - if _, err := t.CopyIn(dataAddr, &data); err != nil { + if _, err := linux.CopyCapUserDataSliceIn(t, dataAddr, data[:]); err != nil { return 0, nil, err } p := (auth.CapabilitySet(data[0].Permitted) | (auth.CapabilitySet(data[1].Permitted) << 32)) & auth.AllCapabilities @@ -141,7 +141,7 @@ func Capset(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal default: hdr.Version = linux.HighestCapabilityVersion - if _, err := t.CopyOut(hdrAddr, &hdr); err != nil { + if _, err := hdr.CopyOut(t, hdrAddr); err != nil { return 0, nil, err } return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 1bc9b184e..98331eb3c 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -19,6 +19,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" @@ -184,7 +185,7 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint file, err := d.Inode.GetFile(t, d, fileFlags) if err != nil { - return syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } defer file.DecRef(t) @@ -414,7 +415,7 @@ func createAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint, mode l // Create a new fs.File. newFile, err = found.Inode.GetFile(t, found, fileFlags) if err != nil { - return syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } defer newFile.DecRef(t) case syserror.ENOENT: @@ -601,19 +602,19 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Shared flags between file and socket. switch request { case linux.FIONCLEX: - t.FDTable().SetFlags(fd, kernel.FDFlags{ + t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: false, }) return 0, nil, nil case linux.FIOCLEX: - t.FDTable().SetFlags(fd, kernel.FDFlags{ + t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: true, }) return 0, nil, nil case linux.FIONBIO: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.Flags() @@ -627,7 +628,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOASYNC: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.Flags() @@ -641,15 +642,14 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOSETOWN, linux.SIOCSPGRP: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } fSetOwn(t, file, set) return 0, nil, nil case linux.FIOGETOWN, linux.SIOCGPGRP: - who := fGetOwn(t, file) - _, err := t.CopyOut(args[2].Pointer(), &who) + _, err := primitive.CopyInt32Out(t, args[2].Pointer(), fGetOwn(t, file)) return 0, nil, err default: @@ -694,7 +694,7 @@ func Getcwd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Top it off with a terminator. - _, err = t.CopyOut(addr+usermem.Addr(bytes), []byte("\x00")) + _, err = t.CopyOutBytes(addr+usermem.Addr(bytes), []byte("\x00")) return uintptr(bytes + 1), nil, err } @@ -787,7 +787,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that Remove provides a reference on the file that we may use to // flush. It is still active until we drop the final reference below // (and other reference-holding operations complete). - file, _ := t.FDTable().Remove(fd) + file, _ := t.FDTable().Remove(t, fd) if file == nil { return 0, nil, syserror.EBADF } @@ -941,7 +941,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - err := t.FDTable().SetFlags(fd, kernel.FDFlags{ + err := t.FDTable().SetFlags(t, fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) return 0, nil, err @@ -962,7 +962,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock - if _, err := t.CopyIn(flockAddr, &flock); err != nil { + if _, err := flock.CopyIn(t, flockAddr); err != nil { return 0, nil, err } @@ -1052,12 +1052,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) - _, err := t.CopyOut(addr, &owner) + _, err := owner.CopyOut(t, addr) return 0, nil, err case linux.F_SETOWN_EX: addr := args[2].Pointer() var owner linux.FOwnerEx - _, err := t.CopyIn(addr, &owner) + _, err := owner.CopyIn(t, addr) if err != nil { return 0, nil, err } @@ -1154,6 +1154,10 @@ func Fadvise64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, nil } +// LINT.ThenChange(vfs2/fd.go) + +// LINT.IfChange + func mkdirAt(t *kernel.Task, dirFD int32, addr usermem.Addr, mode linux.FileMode) error { path, _, err := copyInPath(t, addr, false /* allowEmpty */) if err != nil { @@ -1918,7 +1922,7 @@ func Utime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times linux.Utime - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := times.CopyIn(t, timesAddr); err != nil { return 0, nil, err } ts = fs.TimeSpec{ @@ -1938,7 +1942,7 @@ func Utimes(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } ts = fs.TimeSpec{ @@ -1966,7 +1970,7 @@ func Utimensat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timespec - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } if !timespecIsValid(times[0]) || !timespecIsValid(times[1]) { @@ -2000,7 +2004,7 @@ func Futimesat(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ts := defaultSetToSystemTimeSpec() if timesAddr != 0 { var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return 0, nil, err } if times[0].Usec >= 1e6 || times[0].Usec < 0 || diff --git a/pkg/sentry/syscalls/linux/sys_futex.go b/pkg/sentry/syscalls/linux/sys_futex.go index 9d1b2edb1..f39ce0639 100644 --- a/pkg/sentry/syscalls/linux/sys_futex.go +++ b/pkg/sentry/syscalls/linux/sys_futex.go @@ -74,7 +74,7 @@ func futexWaitAbsolute(t *kernel.Task, clockRealtime bool, ts linux.Timespec, fo } t.Futex().WaitComplete(w, t) - return 0, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // futexWaitDuration performs a FUTEX_WAIT, blocking until the wait is @@ -110,7 +110,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add // The wait duration was absolute, restart with the original arguments. if forever { - return 0, kernel.ERESTARTSYS + return 0, syserror.ERESTARTSYS } // The wait duration was relative, restart with the remaining duration. @@ -121,7 +121,7 @@ func futexWaitDuration(t *kernel.Task, duration time.Duration, forever bool, add val: val, mask: mask, }) - return 0, kernel.ERESTART_RESTARTBLOCK + return 0, syserror.ERESTART_RESTARTBLOCK } func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.Addr, private bool) error { @@ -149,7 +149,7 @@ func futexLockPI(t *kernel.Task, ts linux.Timespec, forever bool, addr usermem.A } t.Futex().WaitComplete(w, t) - return syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) } func tryLockPI(t *kernel.Task, addr usermem.Addr, private bool) error { @@ -306,8 +306,8 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel // Despite the syscall using the name 'pid' for this variable, it is // very much a tid. tid := args[0].Int() - head := args[1].Pointer() - size := args[2].Pointer() + headAddr := args[1].Pointer() + sizeAddr := args[2].Pointer() if tid < 0 { return 0, nil, syserror.EINVAL @@ -321,12 +321,16 @@ func GetRobustList(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel } // Copy out head pointer. - if _, err := t.CopyOut(head, uint64(ot.GetRobustList())); err != nil { + head := t.Arch().Native(uintptr(ot.GetRobustList())) + if _, err := head.CopyOut(t, headAddr); err != nil { return 0, nil, err } - // Copy out size, which is a constant. - if _, err := t.CopyOut(size, uint64(linux.SizeOfRobustListHead)); err != nil { + // Copy out size, which is a constant. Note that while size isn't + // an address, it is defined as the arch-dependent size_t, so it + // needs to be converted to a native-sized int. + size := t.Arch().Native(uintptr(linux.SizeOfRobustListHead)) + if _, err := size.CopyOut(t, sizeAddr); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_getdents.go b/pkg/sentry/syscalls/linux/sys_getdents.go index f5699e55d..b25f7d881 100644 --- a/pkg/sentry/syscalls/linux/sys_getdents.go +++ b/pkg/sentry/syscalls/linux/sys_getdents.go @@ -19,7 +19,6 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -82,7 +81,7 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir ds := newDirentSerializer(f, w, t.Arch(), size) rerr := dir.Readdir(t, ds) - switch err := handleIOError(t, ds.Written() > 0, rerr, kernel.ERESTARTSYS, "getdents", dir); err { + switch err := handleIOError(t, ds.Written() > 0, rerr, syserror.ERESTARTSYS, "getdents", dir); err { case nil: dir.Dirent.InotifyEvent(linux.IN_ACCESS, 0) return uintptr(ds.Written()), nil @@ -93,19 +92,23 @@ func getdents(t *kernel.Task, fd int32, addr usermem.Addr, size int, f func(*dir } } -// oldDirentHdr is a fixed sized header matching the fixed size -// fields found in the old linux dirent struct. +// oldDirentHdr is a fixed sized header matching the fixed size fields found in +// the old linux dirent struct. +// +// +marshal type oldDirentHdr struct { Ino uint64 Off uint64 - Reclen uint16 + Reclen uint16 `marshal:"unaligned"` // Struct ends mid-word. } -// direntHdr is a fixed sized header matching the fixed size -// fields found in the new linux dirent struct. +// direntHdr is a fixed sized header matching the fixed size fields found in the +// new linux dirent struct. +// +// +marshal type direntHdr struct { OldHdr oldDirentHdr - Typ uint8 + Typ uint8 `marshal:"unaligned"` // Struct ends mid-word. } // dirent contains the data pointed to by a new linux dirent struct. @@ -134,20 +137,20 @@ func newDirent(width uint, name string, attr fs.DentAttr, offset uint64) *dirent // the old linux dirent format. func smallestDirent(a arch.Context) uint { d := dirent{} - return uint(binary.Size(d.Hdr.OldHdr)) + a.Width() + 1 + return uint(d.Hdr.OldHdr.SizeBytes()) + a.Width() + 1 } // smallestDirent64 returns the size of the smallest possible dirent using // the new linux dirent format. func smallestDirent64(a arch.Context) uint { d := dirent{} - return uint(binary.Size(d.Hdr)) + a.Width() + return uint(d.Hdr.SizeBytes()) + a.Width() } // padRec pads the name field until the rec length is a multiple of the width, // which must be a power of 2. It returns the padded rec length. func (d *dirent) padRec(width int) uint16 { - a := int(binary.Size(d.Hdr)) + len(d.Name) + a := d.Hdr.SizeBytes() + len(d.Name) r := (a + width) &^ (width - 1) padding := r - a d.Name = append(d.Name, make([]byte, padding)...) @@ -157,7 +160,7 @@ func (d *dirent) padRec(width int) uint16 { // Serialize64 serializes a Dirent struct to a byte slice, keeping the new // linux dirent format. Returns the number of bytes serialized or an error. func (d *dirent) Serialize64(w io.Writer) (int, error) { - n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr)) + n1, err := d.Hdr.WriteTo(w) if err != nil { return 0, err } @@ -165,14 +168,14 @@ func (d *dirent) Serialize64(w io.Writer) (int, error) { if err != nil { return 0, err } - return n1 + n2, nil + return int(n1) + n2, nil } // Serialize serializes a Dirent struct to a byte slice, using the old linux // dirent format. // Returns the number of bytes serialized or an error. func (d *dirent) Serialize(w io.Writer) (int, error) { - n1, err := w.Write(binary.Marshal(nil, usermem.ByteOrder, d.Hdr.OldHdr)) + n1, err := d.Hdr.OldHdr.WriteTo(w) if err != nil { return 0, err } @@ -184,7 +187,7 @@ func (d *dirent) Serialize(w io.Writer) (int, error) { if err != nil { return 0, err } - return n1 + n2 + n3, nil + return int(n1) + n2 + n3, nil } // direntSerializer implements fs.InodeOperationsInfoSerializer, serializing dirents to an diff --git a/pkg/sentry/syscalls/linux/sys_identity.go b/pkg/sentry/syscalls/linux/sys_identity.go index 715ac45e6..a29d307e5 100644 --- a/pkg/sentry/syscalls/linux/sys_identity.go +++ b/pkg/sentry/syscalls/linux/sys_identity.go @@ -49,13 +49,13 @@ func Getresuid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys ruid := c.RealKUID.In(c.UserNamespace).OrOverflow() euid := c.EffectiveKUID.In(c.UserNamespace).OrOverflow() suid := c.SavedKUID.In(c.UserNamespace).OrOverflow() - if _, err := t.CopyOut(ruidAddr, ruid); err != nil { + if _, err := ruid.CopyOut(t, ruidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(euidAddr, euid); err != nil { + if _, err := euid.CopyOut(t, euidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(suidAddr, suid); err != nil { + if _, err := suid.CopyOut(t, suidAddr); err != nil { return 0, nil, err } return 0, nil, nil @@ -84,13 +84,13 @@ func Getresgid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys rgid := c.RealKGID.In(c.UserNamespace).OrOverflow() egid := c.EffectiveKGID.In(c.UserNamespace).OrOverflow() sgid := c.SavedKGID.In(c.UserNamespace).OrOverflow() - if _, err := t.CopyOut(rgidAddr, rgid); err != nil { + if _, err := rgid.CopyOut(t, rgidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(egidAddr, egid); err != nil { + if _, err := egid.CopyOut(t, egidAddr); err != nil { return 0, nil, err } - if _, err := t.CopyOut(sgidAddr, sgid); err != nil { + if _, err := sgid.CopyOut(t, sgidAddr); err != nil { return 0, nil, err } return 0, nil, nil @@ -157,7 +157,7 @@ func Getgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys for i, kgid := range kgids { gids[i] = kgid.In(t.UserNamespace()).OrOverflow() } - if _, err := t.CopyOut(args[1].Pointer(), gids); err != nil { + if _, err := auth.CopyGIDSliceOut(t, args[1].Pointer(), gids); err != nil { return 0, nil, err } return uintptr(len(gids)), nil, nil @@ -173,7 +173,7 @@ func Setgroups(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, t.SetExtraGIDs(nil) } gids := make([]auth.GID, size) - if _, err := t.CopyIn(args[1].Pointer(), &gids); err != nil { + if _, err := auth.CopyGIDSliceIn(t, args[1].Pointer(), gids); err != nil { return 0, nil, err } return 0, nil, t.SetExtraGIDs(gids) diff --git a/pkg/sentry/syscalls/linux/sys_lseek.go b/pkg/sentry/syscalls/linux/sys_lseek.go index 1c38f8f4f..0046347cb 100644 --- a/pkg/sentry/syscalls/linux/sys_lseek.go +++ b/pkg/sentry/syscalls/linux/sys_lseek.go @@ -48,7 +48,7 @@ func Lseek(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } offset, serr := file.Seek(t, sw, offset) - err := handleIOError(t, false /* partialResult */, serr, kernel.ERESTARTSYS, "lseek", file) + err := handleIOError(t, false /* partialResult */, serr, syserror.ERESTARTSYS, "lseek", file) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 72786b032..cd8dfdfa4 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -100,6 +100,15 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if err := file.ConfigureMMap(t, &opts); err != nil { return 0, nil, err } + } else if shared { + // Back shared anonymous mappings with a special mappable. + opts.Offset = 0 + m, err := mm.NewSharedAnonMappable(opts.Length, t.Kernel()) + if err != nil { + return 0, nil, err + } + opts.MappingIdentity = m // transfers ownership of m to opts + opts.Mappable = m } rv, err := t.MemoryManager().MMap(t, opts) @@ -239,7 +248,7 @@ func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca return 0, nil, syserror.ENOMEM } resident := bytes.Repeat([]byte{1}, int(mapped/usermem.PageSize)) - _, err := t.CopyOut(vec, resident) + _, err := t.CopyOutBytes(vec, resident) return 0, nil, err } @@ -267,7 +276,7 @@ func Msync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall }) // MSync calls fsync, the same interrupt conversion rules apply, see // mm/msync.c, fsync POSIX.1-2008. - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // Mlock implements linux syscall mlock(2). diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 3149e4aad..849a47476 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -46,9 +47,9 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { return 0, err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if file, _ := t.FDTable().Remove(fd); file != nil { + if file, _ := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } diff --git a/pkg/sentry/syscalls/linux/sys_poll.go b/pkg/sentry/syscalls/linux/sys_poll.go index 3435bdf77..254f4c9f9 100644 --- a/pkg/sentry/syscalls/linux/sys_poll.go +++ b/pkg/sentry/syscalls/linux/sys_poll.go @@ -162,7 +162,7 @@ func CopyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD pfd := make([]linux.PollFD, nfds) if nfds > 0 { - if _, err := t.CopyIn(addr, &pfd); err != nil { + if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil { return nil, err } } @@ -189,7 +189,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. if nfds > 0 && err == nil { - if _, err := t.CopyOut(addr, pfd); err != nil { + if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil { return remainingTimeout, 0, err } } @@ -202,7 +202,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy set := make([]byte, nBytes) if addr != 0 { - if _, err := t.CopyIn(addr, &set); err != nil { + if _, err := t.CopyInBytes(addr, set); err != nil { return nil, err } // If we only use part of the last byte, mask out the extraneous bits. @@ -329,19 +329,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Copy updated vectors back. if readFDs != 0 { - if _, err := t.CopyOut(readFDs, r); err != nil { + if _, err := t.CopyOutBytes(readFDs, r); err != nil { return 0, err } } if writeFDs != 0 { - if _, err := t.CopyOut(writeFDs, w); err != nil { + if _, err := t.CopyOutBytes(writeFDs, w); err != nil { return 0, err } } if exceptFDs != 0 { - if _, err := t.CopyOut(exceptFDs, e); err != nil { + if _, err := t.CopyOutBytes(exceptFDs, e); err != nil { return 0, err } } @@ -410,7 +410,7 @@ func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration nfds: nfds, timeout: remainingTimeout, }) - return 0, kernel.ERESTART_RESTARTBLOCK + return 0, syserror.ERESTART_RESTARTBLOCK } return n, err } @@ -464,7 +464,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } @@ -494,7 +494,7 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } @@ -539,7 +539,7 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index 64a725296..a892d2c62 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -18,6 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -43,7 +44,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, nil case linux.PR_GET_PDEATHSIG: - _, err := t.CopyOut(args[1].Pointer(), int32(t.ParentDeathSignal())) + _, err := primitive.CopyInt32Out(t, args[1].Pointer(), int32(t.ParentDeathSignal())) return 0, nil, err case linux.PR_GET_DUMPABLE: @@ -110,7 +111,7 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall buf[len] = 0 len++ } - _, err := t.CopyOut(addr, buf[:len]) + _, err := t.CopyOutBytes(addr, buf[:len]) if err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go index 3bbc3fa4b..f655d3db1 100644 --- a/pkg/sentry/syscalls/linux/sys_read.go +++ b/pkg/sentry/syscalls/linux/sys_read.go @@ -71,7 +71,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "read", file) } // Readahead implements readahead(2). @@ -151,7 +151,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file) } // Readv implements linux syscall readv(2). @@ -181,7 +181,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "readv", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "readv", file) } // Preadv implements linux syscall preadv(2). @@ -222,7 +222,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file) } // Preadv2 implements linux syscall preadv2(2). @@ -280,12 +280,12 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca if offset == -1 { n, err := readv(t, file, dst) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) } n, err := preadv(t, file, dst, offset) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) } func readv(t *kernel.Task, f *fs.File, dst usermem.IOSequence) (int64, error) { diff --git a/pkg/sentry/syscalls/linux/sys_rlimit.go b/pkg/sentry/syscalls/linux/sys_rlimit.go index d5d5b6959..309c183a3 100644 --- a/pkg/sentry/syscalls/linux/sys_rlimit.go +++ b/pkg/sentry/syscalls/linux/sys_rlimit.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/limits" @@ -26,17 +27,13 @@ import ( // rlimit describes an implementation of 'struct rlimit', which may vary from // system-to-system. type rlimit interface { + marshal.Marshallable + // toLimit converts an rlimit to a limits.Limit. toLimit() *limits.Limit // fromLimit converts a limits.Limit to an rlimit. fromLimit(lim limits.Limit) - - // copyIn copies an rlimit from the untrusted app to the kernel. - copyIn(t *kernel.Task, addr usermem.Addr) error - - // copyOut copies an rlimit from the kernel to the untrusted app. - copyOut(t *kernel.Task, addr usermem.Addr) error } // newRlimit returns the appropriate rlimit type for 'struct rlimit' on this system. @@ -50,6 +47,7 @@ func newRlimit(t *kernel.Task) (rlimit, error) { } } +// +marshal type rlimit64 struct { Cur uint64 Max uint64 @@ -70,12 +68,12 @@ func (r *rlimit64) fromLimit(lim limits.Limit) { } func (r *rlimit64) copyIn(t *kernel.Task, addr usermem.Addr) error { - _, err := t.CopyIn(addr, r) + _, err := r.CopyIn(t, addr) return err } func (r *rlimit64) copyOut(t *kernel.Task, addr usermem.Addr) error { - _, err := t.CopyOut(addr, *r) + _, err := r.CopyOut(t, addr) return err } @@ -140,7 +138,8 @@ func Getrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } rlim.fromLimit(lim) - return 0, nil, rlim.copyOut(t, addr) + _, err = rlim.CopyOut(t, addr) + return 0, nil, err } // Setrlimit implements linux syscall setrlimit(2). @@ -155,7 +154,7 @@ func Setrlimit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if err != nil { return 0, nil, err } - if err := rlim.copyIn(t, addr); err != nil { + if _, err := rlim.CopyIn(t, addr); err != nil { return 0, nil, syserror.EFAULT } _, err = prlimit64(t, resource, rlim.toLimit()) diff --git a/pkg/sentry/syscalls/linux/sys_rusage.go b/pkg/sentry/syscalls/linux/sys_rusage.go index 1674c7445..ac5c98a54 100644 --- a/pkg/sentry/syscalls/linux/sys_rusage.go +++ b/pkg/sentry/syscalls/linux/sys_rusage.go @@ -80,7 +80,7 @@ func Getrusage(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } ru := getrusage(t, which) - _, err := t.CopyOut(addr, &ru) + _, err := ru.CopyOut(t, addr) return 0, nil, err } @@ -104,7 +104,7 @@ func Times(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall CUTime: linux.ClockTFromDuration(cs2.UserTime), CSTime: linux.ClockTFromDuration(cs2.SysTime), } - if _, err := t.CopyOut(addr, &r); err != nil { + if _, err := r.CopyOut(t, addr); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_sched.go b/pkg/sentry/syscalls/linux/sys_sched.go index 99f6993f5..bfcf44b6f 100644 --- a/pkg/sentry/syscalls/linux/sys_sched.go +++ b/pkg/sentry/syscalls/linux/sys_sched.go @@ -27,8 +27,10 @@ const ( ) // SchedParam replicates struct sched_param in sched.h. +// +// +marshal type SchedParam struct { - schedPriority int64 + schedPriority int32 } // SchedGetparam implements linux syscall sched_getparam(2). @@ -45,7 +47,7 @@ func SchedGetparam(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel return 0, nil, syserror.ESRCH } r := SchedParam{schedPriority: onlyPriority} - if _, err := t.CopyOut(param, r); err != nil { + if _, err := r.CopyOut(t, param); err != nil { return 0, nil, err } @@ -79,7 +81,7 @@ func SchedSetscheduler(t *kernel.Task, args arch.SyscallArguments) (uintptr, *ke return 0, nil, syserror.ESRCH } var r SchedParam - if _, err := t.CopyIn(param, &r); err != nil { + if _, err := r.CopyIn(t, param); err != nil { return 0, nil, syserror.EINVAL } if r.schedPriority != onlyPriority { diff --git a/pkg/sentry/syscalls/linux/sys_seccomp.go b/pkg/sentry/syscalls/linux/sys_seccomp.go index 5b7a66f4d..4fdb4463c 100644 --- a/pkg/sentry/syscalls/linux/sys_seccomp.go +++ b/pkg/sentry/syscalls/linux/sys_seccomp.go @@ -24,6 +24,8 @@ import ( ) // userSockFprog is equivalent to Linux's struct sock_fprog on amd64. +// +// +marshal type userSockFprog struct { // Len is the length of the filter in BPF instructions. Len uint16 @@ -33,7 +35,7 @@ type userSockFprog struct { // Filter is a user pointer to the struct sock_filter array that makes up // the filter program. Filter is a uint64 rather than a usermem.Addr // because usermem.Addr is actually uintptr, which is not a fixed-size - // type, and encoding/binary.Read objects to this. + // type. Filter uint64 } @@ -54,11 +56,11 @@ func seccomp(t *kernel.Task, mode, flags uint64, addr usermem.Addr) error { } var fprog userSockFprog - if _, err := t.CopyIn(addr, &fprog); err != nil { + if _, err := fprog.CopyIn(t, addr); err != nil { return err } filter := make([]linux.BPFInstruction, int(fprog.Len)) - if _, err := t.CopyIn(usermem.Addr(fprog.Filter), &filter); err != nil { + if _, err := linux.CopyBPFInstructionSliceIn(t, usermem.Addr(fprog.Filter), filter); err != nil { return err } compiledFilter, err := bpf.Compile(filter) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 5f54f2456..47dadb800 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -18,6 +18,7 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -66,7 +67,7 @@ func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } ops := make([]linux.Sembuf, nsops) - if _, err := t.CopyIn(sembufAddr, ops); err != nil { + if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil { return 0, nil, err } @@ -116,8 +117,8 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case linux.IPC_SET: arg := args[3].Pointer() - s := linux.SemidDS{} - if _, err := t.CopyIn(arg, &s); err != nil { + var s linux.SemidDS + if _, err := s.CopyIn(t, arg); err != nil { return 0, nil, err } @@ -188,7 +189,7 @@ func setValAll(t *kernel.Task, id int32, array usermem.Addr) error { return syserror.EINVAL } vals := make([]uint16, set.Size()) - if _, err := t.CopyIn(array, vals); err != nil { + if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil { return err } creds := auth.CredentialsFromContext(t) @@ -217,7 +218,7 @@ func getValAll(t *kernel.Task, id int32, array usermem.Addr) error { if err != nil { return err } - _, err = t.CopyOut(array, vals) + _, err = primitive.CopyUint16SliceOut(t, array, vals) return err } diff --git a/pkg/sentry/syscalls/linux/sys_shm.go b/pkg/sentry/syscalls/linux/sys_shm.go index f0ae8fa8e..584064143 100644 --- a/pkg/sentry/syscalls/linux/sys_shm.go +++ b/pkg/sentry/syscalls/linux/sys_shm.go @@ -112,18 +112,18 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal stat, err := segment.IPCStat(t) if err == nil { - _, err = t.CopyOut(buf, stat) + _, err = stat.CopyOut(t, buf) } return 0, nil, err case linux.IPC_INFO: params := r.IPCInfo() - _, err := t.CopyOut(buf, params) + _, err := params.CopyOut(t, buf) return 0, nil, err case linux.SHM_INFO: info := r.ShmInfo() - _, err := t.CopyOut(buf, info) + _, err := info.CopyOut(t, buf) return 0, nil, err } @@ -137,11 +137,10 @@ func Shmctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal switch cmd { case linux.IPC_SET: var ds linux.ShmidDS - _, err = t.CopyIn(buf, &ds) - if err != nil { + if _, err = ds.CopyIn(t, buf); err != nil { return 0, nil, err } - err = segment.Set(t, &ds) + err := segment.Set(t, &ds) return 0, nil, err case linux.IPC_RMID: diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 20cb1a5cb..e748d33d8 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -348,7 +348,7 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S // Pause implements linux syscall pause(2). func Pause(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND) + return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND) } // RtSigpending implements linux syscall rt_sigpending(2). @@ -496,7 +496,7 @@ func RtSigsuspend(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. t.SetSavedSignalMask(oldmask) // Perform the wait. - return 0, nil, syserror.ConvertIntr(t.Block(nil), kernel.ERESTARTNOHAND) + return 0, nil, syserror.ConvertIntr(t.Block(nil), syserror.ERESTARTNOHAND) } // RestartSyscall implements the linux syscall restart_syscall(2). diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index fec1c1974..9feaca0da 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -29,8 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // LINT.IfChange @@ -67,10 +67,10 @@ const flagsOffset = 48 const sizeOfInt32 = 4 // messageHeader64Len is the length of a MessageHeader64 struct. -var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) +var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes()) // multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct. -var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{})) +var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes()) // baseRecvFlags are the flags that are accepted across recvmsg(2), // recvmmsg(2), and recvfrom(2). @@ -78,6 +78,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | // MessageHeader64 is the 64-bit representation of the msghdr struct used in // the recvmsg and sendmsg syscalls. +// +// +marshal type MessageHeader64 struct { // Name is the optional pointer to a network address buffer. Name uint64 @@ -106,30 +108,14 @@ type MessageHeader64 struct { // multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in // the recvmmsg and sendmmsg syscalls. +// +// +marshal type multipleMessageHeader64 struct { msgHdr MessageHeader64 msgLen uint32 _ int32 } -// CopyInMessageHeader64 copies a message header from user to kernel memory. -func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error { - b := t.CopyScratchBuffer(52) - if _, err := t.CopyInBytes(addr, b); err != nil { - return err - } - - msg.Name = usermem.ByteOrder.Uint64(b[0:]) - msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) - msg.Iov = usermem.ByteOrder.Uint64(b[16:]) - msg.IovLen = usermem.ByteOrder.Uint64(b[24:]) - msg.Control = usermem.ByteOrder.Uint64(b[32:]) - msg.ControlLen = usermem.ByteOrder.Uint64(b[40:]) - msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:])) - - return nil -} - // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { @@ -148,10 +134,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { // Get the buffer length. var bufLen uint32 - if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil { + if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { return err } @@ -160,7 +146,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Write the length unconditionally. - if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil { + if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil { return err } @@ -173,7 +159,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Copy as much of the address as will fit in the buffer. - encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr) + encodedAddr := t.CopyScratchBuffer(addr.SizeBytes()) + addr.MarshalUnsafe(encodedAddr) if bufLen > uint32(len(encodedAddr)) { bufLen = uint32(len(encodedAddr)) } @@ -247,9 +234,9 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Copy the file descriptors out. - if _, err := t.CopyOut(socks, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, socks, fds); err != nil { for _, fd := range fds { - if file, _ := t.FDTable().Remove(fd); file != nil { + if file, _ := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } @@ -285,7 +272,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } blocking := !file.Flags().NonBlocking - return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS) } // accept is the implementation of the accept syscall. It is called by accept @@ -316,7 +303,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f peerRequested := addrLen != 0 nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } if peerRequested { // NOTE(magi): Linux does not give you an error if it can't @@ -456,8 +443,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Read the length. Reject negative values. - optLen := int32(0) - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + var optLen int32 + if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil { return 0, nil, err } if optLen < 0 { @@ -471,7 +458,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } @@ -733,7 +720,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -748,7 +735,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -771,7 +758,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) if err != nil { - return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS) } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC @@ -780,7 +767,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i if int(msg.Flags) != mflags { // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } } @@ -793,7 +780,7 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } defer cms.Release(t) @@ -817,17 +804,17 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i } // Copy the control data to the caller. - if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } @@ -882,7 +869,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) cm.Release(t) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } // Copy the address to the caller. @@ -996,7 +983,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -1011,7 +998,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr usermem.Addr, flags int32) (uintptr, error) { // Capture the message header. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -1022,7 +1009,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -1064,7 +1051,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) - err = handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) + err = handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) if err != nil { controlMessages.Release(t) } @@ -1124,7 +1111,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) - return uintptr(n), handleIOError(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file) + return uintptr(n), handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file) } // SendTo implements the linux syscall sendto(2). diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index b8846a10a..46616c961 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -141,7 +142,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // Copy in the offset. var offset int64 - if _, err := t.CopyIn(offsetAddr, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, offsetAddr, &offset); err != nil { return 0, nil, err } @@ -149,11 +150,11 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err = doSplice(t, outFile, inFile, fs.SpliceOpts{ Length: count, SrcOffset: true, - SrcStart: offset, + SrcStart: int64(offset), }, outFile.Flags().NonBlocking) // Copy out the new offset. - if _, err := t.CopyOut(offsetAddr, n+offset); err != nil { + if _, err := primitive.CopyInt64Out(t, offsetAddr, offset+n); err != nil { return 0, nil, err } } else { @@ -170,7 +171,7 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc // We can only pass a single file to handleIOError, so pick inFile // arbitrarily. This is used only for debugging purposes. - return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "sendfile", inFile) + return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "sendfile", inFile) } // Splice implements splice(2). @@ -228,7 +229,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } var offset int64 - if _, err := t.CopyIn(outOffset, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, outOffset, &offset); err != nil { return 0, nil, err } @@ -246,7 +247,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } var offset int64 - if _, err := t.CopyIn(inOffset, &offset); err != nil { + if _, err := primitive.CopyInt64In(t, inOffset, &offset); err != nil { return 0, nil, err } @@ -280,7 +281,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "splice", inFile) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "splice", inFile) } // Tee imlements tee(2). @@ -333,5 +334,5 @@ func Tee(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallCo } // See above; inFile is chosen arbitrarily here. - return uintptr(n), nil, handleIOError(t, false, err, kernel.ERESTARTSYS, "tee", inFile) + return uintptr(n), nil, handleIOError(t, false, err, syserror.ERESTARTSYS, "tee", inFile) } diff --git a/pkg/sentry/syscalls/linux/sys_stat.go b/pkg/sentry/syscalls/linux/sys_stat.go index a5826f2dd..cda29a8b5 100644 --- a/pkg/sentry/syscalls/linux/sys_stat.go +++ b/pkg/sentry/syscalls/linux/sys_stat.go @@ -221,7 +221,7 @@ func statx(t *kernel.Task, sattr fs.StableAttr, uattr fs.UnstableAttr, statxAddr DevMajor: uint32(devMajor), DevMinor: devMinor, } - _, err := t.CopyOut(statxAddr, &s) + _, err := s.CopyOut(t, statxAddr) return err } @@ -283,7 +283,7 @@ func statfsImpl(t *kernel.Task, d *fs.Dirent, addr usermem.Addr) error { FragmentSize: d.Inode.StableAttr.BlockSize, // Leave other fields 0 like simple_statfs does. } - _, err = t.CopyOut(addr, &statfs) + _, err = statfs.CopyOut(t, addr) return err } diff --git a/pkg/sentry/syscalls/linux/sys_sync.go b/pkg/sentry/syscalls/linux/sys_sync.go index f2c0e5069..048a21c6e 100644 --- a/pkg/sentry/syscalls/linux/sys_sync.go +++ b/pkg/sentry/syscalls/linux/sys_sync.go @@ -57,7 +57,7 @@ func Fsync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncAll) - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // Fdatasync implements linux syscall fdatasync(2). @@ -73,7 +73,7 @@ func Fdatasync(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer file.DecRef(t) err := file.Fsync(t, 0, fs.FileMaxOffset, fs.SyncData) - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // SyncFileRange implements linux syscall sync_file_rage(2) @@ -135,7 +135,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel err = file.Fsync(t, offset, fs.FileMaxOffset, fs.SyncData) } - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } // LINT.ThenChange(vfs2/sync.go) diff --git a/pkg/sentry/syscalls/linux/sys_sysinfo.go b/pkg/sentry/syscalls/linux/sys_sysinfo.go index 297de052a..674d341b6 100644 --- a/pkg/sentry/syscalls/linux/sys_sysinfo.go +++ b/pkg/sentry/syscalls/linux/sys_sysinfo.go @@ -43,6 +43,6 @@ func Sysinfo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca FreeRAM: memFree, Unit: 1, } - _, err := t.CopyOut(addr, si) + _, err := si.CopyOut(t, addr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 2d16e4933..39ca9ea97 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -19,6 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fsbridge" @@ -262,7 +263,7 @@ func parseCommonWaitOptions(wopts *kernel.WaitOptions, options int) error { wopts.Events |= kernel.EventGroupContinue } if options&linux.WNOHANG == 0 { - wopts.BlockInterruptErr = kernel.ERESTARTSYS + wopts.BlockInterruptErr = syserror.ERESTARTSYS } if options&linux.WNOTHREAD == 0 { wopts.SiblingChildren = true @@ -311,13 +312,13 @@ func wait4(t *kernel.Task, pid int, statusAddr usermem.Addr, options int, rusage return 0, err } if statusAddr != 0 { - if _, err := t.CopyOut(statusAddr, wr.Status); err != nil { + if _, err := primitive.CopyUint32Out(t, statusAddr, wr.Status); err != nil { return 0, err } } if rusageAddr != 0 { ru := getrusage(wr.Task, linux.RUSAGE_BOTH) - if _, err := t.CopyOut(rusageAddr, &ru); err != nil { + if _, err := ru.CopyOut(t, rusageAddr); err != nil { return 0, err } } @@ -395,14 +396,14 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // as well. if infop != 0 { var si arch.SignalInfo - _, err = t.CopyOut(infop, &si) + _, err = si.CopyOut(t, infop) } } return 0, nil, err } if rusageAddr != 0 { ru := getrusage(wr.Task, linux.RUSAGE_BOTH) - if _, err := t.CopyOut(rusageAddr, &ru); err != nil { + if _, err := ru.CopyOut(t, rusageAddr); err != nil { return 0, nil, err } } @@ -441,7 +442,7 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal default: t.Warningf("waitid got incomprehensible wait status %d", s) } - _, err = t.CopyOut(infop, &si) + _, err = si.CopyOut(t, infop) return 0, nil, err } @@ -558,9 +559,7 @@ func Getcpu(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // third argument to this system call is nowadays unused. if cpu != 0 { - buf := t.CopyScratchBuffer(4) - usermem.ByteOrder.PutUint32(buf, uint32(t.CPU())) - if _, err := t.CopyOutBytes(cpu, buf); err != nil { + if _, err := primitive.CopyInt32Out(t, cpu, t.CPU()); err != nil { return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go index 2d2aa0819..c5054d2f1 100644 --- a/pkg/sentry/syscalls/linux/sys_time.go +++ b/pkg/sentry/syscalls/linux/sys_time.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -168,7 +169,7 @@ func Time(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC return uintptr(r), nil, nil } - if _, err := t.CopyOut(addr, r); err != nil { + if _, err := r.CopyOut(t, addr); err != nil { return 0, nil, err } return uintptr(r), nil, nil @@ -213,7 +214,7 @@ func clockNanosleepUntil(t *kernel.Task, c ktime.Clock, ts linux.Timespec) error return nil } - return syserror.ConvertIntr(err, kernel.ERESTARTNOHAND) + return syserror.ConvertIntr(err, syserror.ERESTARTNOHAND) } // clockNanosleepFor blocks for a specified duration. @@ -254,7 +255,7 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use duration: remaining, rem: rem, }) - return kernel.ERESTART_RESTARTBLOCK + return syserror.ERESTART_RESTARTBLOCK default: panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err)) } @@ -334,8 +335,8 @@ func Gettimeofday(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. // Ask the time package for the timezone. _, offset := time.Now().Zone() // This int32 array mimics linux's struct timezone. - timezone := [2]int32{-int32(offset) / 60, 0} - _, err := t.CopyOut(tz, timezone) + timezone := []int32{-int32(offset) / 60, 0} + _, err := primitive.CopyInt32SliceOut(t, tz, timezone) return 0, nil, err } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/sys_timer.go b/pkg/sentry/syscalls/linux/sys_timer.go index a4c400f87..45eef4feb 100644 --- a/pkg/sentry/syscalls/linux/sys_timer.go +++ b/pkg/sentry/syscalls/linux/sys_timer.go @@ -21,81 +21,63 @@ import ( "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) const nsecPerSec = int64(time.Second) -// copyItimerValIn copies an ItimerVal from the untrusted app range to the -// kernel. The ItimerVal may be either 32 or 64 bits. -// A NULL address is allowed because because Linux allows -// setitimer(which, NULL, &old_value) which disables the timer. -// There is a KERN_WARN message saying this misfeature will be removed. -// However, that hasn't happened as of 3.19, so we continue to support it. -func copyItimerValIn(t *kernel.Task, addr usermem.Addr) (linux.ItimerVal, error) { - if addr == usermem.Addr(0) { - return linux.ItimerVal{}, nil - } - - switch t.Arch().Width() { - case 8: - // Native size, just copy directly. - var itv linux.ItimerVal - if _, err := t.CopyIn(addr, &itv); err != nil { - return linux.ItimerVal{}, err - } - - return itv, nil - default: - return linux.ItimerVal{}, syserror.ENOSYS - } -} - -// copyItimerValOut copies an ItimerVal to the untrusted app range. -// The ItimerVal may be either 32 or 64 bits. -// A NULL address is allowed, in which case no copy takes place -func copyItimerValOut(t *kernel.Task, addr usermem.Addr, itv *linux.ItimerVal) error { - if addr == usermem.Addr(0) { - return nil - } - - switch t.Arch().Width() { - case 8: - // Native size, just copy directly. - _, err := t.CopyOut(addr, itv) - return err - default: - return syserror.ENOSYS - } -} - // Getitimer implements linux syscall getitimer(2). func Getitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + if t.Arch().Width() != 8 { + // Definition of linux.ItimerVal assumes 64-bit architecture. + return 0, nil, syserror.ENOSYS + } + timerID := args[0].Int() - val := args[1].Pointer() + addr := args[1].Pointer() olditv, err := t.Getitimer(timerID) if err != nil { return 0, nil, err } - return 0, nil, copyItimerValOut(t, val, &olditv) + // A NULL address is allowed, in which case no copy out takes place. + if addr == 0 { + return 0, nil, nil + } + _, err = olditv.CopyOut(t, addr) + return 0, nil, err } // Setitimer implements linux syscall setitimer(2). func Setitimer(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - timerID := args[0].Int() - newVal := args[1].Pointer() - oldVal := args[2].Pointer() + if t.Arch().Width() != 8 { + // Definition of linux.ItimerVal assumes 64-bit architecture. + return 0, nil, syserror.ENOSYS + } - newitv, err := copyItimerValIn(t, newVal) - if err != nil { - return 0, nil, err + timerID := args[0].Int() + newAddr := args[1].Pointer() + oldAddr := args[2].Pointer() + + var newitv linux.ItimerVal + // A NULL address is allowed because because Linux allows + // setitimer(which, NULL, &old_value) which disables the timer. There is a + // KERN_WARN message saying this misfeature will be removed. However, that + // hasn't happened as of 3.19, so we continue to support it. + if newAddr != 0 { + if _, err := newitv.CopyIn(t, newAddr); err != nil { + return 0, nil, err + } } olditv, err := t.Setitimer(timerID, newitv) if err != nil { return 0, nil, err } - return 0, nil, copyItimerValOut(t, oldVal, &olditv) + // A NULL address is allowed, in which case no copy out takes place. + if oldAddr == 0 { + return 0, nil, nil + } + _, err = olditv.CopyOut(t, oldAddr) + return 0, nil, err } // Alarm implements linux syscall alarm(2). @@ -131,7 +113,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S var sev *linux.Sigevent if sevp != 0 { sev = &linux.Sigevent{} - if _, err = t.CopyIn(sevp, sev); err != nil { + if _, err = sev.CopyIn(t, sevp); err != nil { return 0, nil, err } } @@ -141,7 +123,7 @@ func TimerCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, err } - if _, err := t.CopyOut(timerIDp, &id); err != nil { + if _, err := id.CopyOut(t, timerIDp); err != nil { t.IntervalTimerDelete(id) return 0, nil, err } @@ -157,7 +139,7 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. oldValAddr := args[3].Pointer() var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } oldVal, err := t.IntervalTimerSettime(timerID, newVal, flags&linux.TIMER_ABSTIME != 0) @@ -165,9 +147,8 @@ func TimerSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. return 0, nil, err } if oldValAddr != 0 { - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { - return 0, nil, err - } + _, err = oldVal.CopyOut(t, oldValAddr) + return 0, nil, err } return 0, nil, nil } @@ -181,7 +162,7 @@ func TimerGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. if err != nil { return 0, nil, err } - _, err = t.CopyOut(curValAddr, &curVal) + _, err = curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_timerfd.go b/pkg/sentry/syscalls/linux/sys_timerfd.go index 34b03e4ee..cadd9d348 100644 --- a/pkg/sentry/syscalls/linux/sys_timerfd.go +++ b/pkg/sentry/syscalls/linux/sys_timerfd.go @@ -81,7 +81,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tf.Clock()) @@ -91,7 +91,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, oldS := tf.SetTime(newS) if oldValAddr != 0 { oldVal := ktime.ItimerspecFromSetting(tm, oldS) - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { + if _, err := oldVal.CopyOut(t, oldValAddr); err != nil { return 0, nil, err } } @@ -116,6 +116,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, s := tf.GetTime() curVal := ktime.ItimerspecFromSetting(tm, s) - _, err := t.CopyOut(curValAddr, &curVal) + _, err := curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_tls_amd64.go b/pkg/sentry/syscalls/linux/sys_tls_amd64.go index b3eb96a1c..6ddd30d5c 100644 --- a/pkg/sentry/syscalls/linux/sys_tls_amd64.go +++ b/pkg/sentry/syscalls/linux/sys_tls_amd64.go @@ -18,6 +18,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -30,17 +31,19 @@ func ArchPrctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys case linux.ARCH_GET_FS: addr := args[1].Pointer() fsbase := t.Arch().TLS() - _, err := t.CopyOut(addr, uint64(fsbase)) - if err != nil { - return 0, nil, err + switch t.Arch().Width() { + case 8: + if _, err := primitive.CopyUint64Out(t, addr, uint64(fsbase)); err != nil { + return 0, nil, err + } + default: + return 0, nil, syserror.ENOSYS } - case linux.ARCH_SET_FS: fsbase := args[1].Uint64() if !t.Arch().SetTLS(uintptr(fsbase)) { return 0, nil, syserror.EPERM } - case linux.ARCH_GET_GS, linux.ARCH_SET_GS: t.Kernel().EmitUnimplementedEvent(t) fallthrough diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go index e9d702e8e..66c5974f5 100644 --- a/pkg/sentry/syscalls/linux/sys_utsname.go +++ b/pkg/sentry/syscalls/linux/sys_utsname.go @@ -46,7 +46,7 @@ func Uname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Copy out the result. va := args[0].Pointer() - _, err := t.CopyOut(va, u) + _, err := u.CopyOut(t, va) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/sys_write.go b/pkg/sentry/syscalls/linux/sys_write.go index 485526e28..95bfe6606 100644 --- a/pkg/sentry/syscalls/linux/sys_write.go +++ b/pkg/sentry/syscalls/linux/sys_write.go @@ -71,7 +71,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "write", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "write", file) } // Pwrite64 implements linux syscall pwrite64(2). @@ -118,7 +118,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file) } // Writev implements linux syscall writev(2). @@ -148,7 +148,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "writev", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "writev", file) } // Pwritev implements linux syscall pwritev(2). @@ -189,7 +189,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file) } // Pwritev2 implements linux syscall pwritev2(2). @@ -250,12 +250,12 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if offset == -1 { n, err := writev(t, file, src) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) } n, err := pwritev(t, file, src, offset) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, handleIOError(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) } func writev(t *kernel.Task, f *fs.File, src usermem.IOSequence) (int64, error) { diff --git a/pkg/sentry/syscalls/linux/vfs2/BUILD b/pkg/sentry/syscalls/linux/vfs2/BUILD index 64696b438..9ee766552 100644 --- a/pkg/sentry/syscalls/linux/vfs2/BUILD +++ b/pkg/sentry/syscalls/linux/vfs2/BUILD @@ -44,6 +44,9 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/gohacks", + "//pkg/log", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", @@ -72,7 +75,5 @@ go_library( "//pkg/syserror", "//pkg/usermem", "//pkg/waiter", - "//tools/go_marshal/marshal", - "//tools/go_marshal/primitive", ], ) diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 42559bf69..6d0a38330 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -17,6 +17,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/eventfd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -38,21 +39,27 @@ func IoSubmit(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } for i := int32(0); i < nrEvents; i++ { - // Copy in the address. - cbAddrNative := t.Arch().Native(0) - if _, err := t.CopyIn(addr, cbAddrNative); err != nil { - if i > 0 { - // Some successful. - return uintptr(i), nil, nil + // Copy in the callback address. + var cbAddr usermem.Addr + switch t.Arch().Width() { + case 8: + var cbAddrP primitive.Uint64 + if _, err := cbAddrP.CopyIn(t, addr); err != nil { + if i > 0 { + // Some successful. + return uintptr(i), nil, nil + } + // Nothing done. + return 0, nil, err } - // Nothing done. - return 0, nil, err + cbAddr = usermem.Addr(cbAddrP) + default: + return 0, nil, syserror.ENOSYS } // Copy in this callback. var cb linux.IOCallback - cbAddr := usermem.Addr(t.Arch().Value(cbAddrNative)) - if _, err := t.CopyIn(cbAddr, &cb); err != nil { + if _, err := cb.CopyIn(t, cbAddr); err != nil { if i > 0 { // Some have been successful. return uintptr(i), nil, nil diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index c62f03509..d0cbb77eb 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -24,7 +24,6 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) @@ -141,50 +140,26 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, syserror.EINVAL } - // Use a fixed-size buffer in a loop, instead of make([]linux.EpollEvent, - // maxEvents), so that the buffer can be allocated on the stack. + // Allocate space for a few events on the stack for the common case in + // which we don't have too many events. var ( - events [16]linux.EpollEvent - total int + eventsArr [16]linux.EpollEvent ch chan struct{} haveDeadline bool deadline ktime.Time ) for { - batchEvents := len(events) - if batchEvents > maxEvents { - batchEvents = maxEvents - } - n := ep.ReadEvents(events[:batchEvents]) - maxEvents -= n - if n != 0 { - // Copy what we read out. - copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events[:n]) + events := ep.ReadEvents(eventsArr[:0], maxEvents) + if len(events) != 0 { + copiedBytes, err := linux.CopyEpollEventSliceOut(t, eventsAddr, events) copiedEvents := copiedBytes / sizeofEpollEvent // rounded down - eventsAddr += usermem.Addr(copiedEvents * sizeofEpollEvent) - total += copiedEvents - if err != nil { - if total != 0 { - return uintptr(total), nil, nil - } - return 0, nil, err - } - // If we've filled the application's event buffer, we're done. - if maxEvents == 0 { - return uintptr(total), nil, nil - } - // Loop if we read a full batch, under the expectation that there - // may be more events to read. - if n == batchEvents { - continue + if copiedEvents != 0 { + return uintptr(copiedEvents), nil, nil } + return 0, nil, err } - // We get here if n != batchEvents. If we read any number of events - // (just now, or in a previous iteration of this loop), or if timeout - // is 0 (such that epoll_wait should be non-blocking), return the - // events we've read so far to the application. - if total != 0 || timeout == 0 { - return uintptr(total), nil, nil + if timeout == 0 { + return 0, nil, nil } // In the first iteration of this loop, register with the epoll // instance for readability events, but then immediately continue the @@ -207,8 +182,6 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys if err == syserror.ETIMEDOUT { err = nil } - // total must be 0 since otherwise we would have returned - // above. return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 4856554fe..d8b8d9783 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -34,7 +34,7 @@ func Close(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that Remove provides a reference on the file that we may use to // flush. It is still active until we drop the final reference below // (and other reference-holding operations complete). - _, file := t.FDTable().Remove(fd) + _, file := t.FDTable().Remove(t, fd) if file == nil { return 0, nil, syserror.EBADF } @@ -137,7 +137,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return uintptr(flags.ToLinuxFDFlags()), nil, nil case linux.F_SETFD: flags := args[2].Uint() - err := t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + err := t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: flags&linux.FD_CLOEXEC != 0, }) return 0, nil, err @@ -181,11 +181,11 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if !hasOwner { return 0, nil, nil } - _, err := t.CopyOut(args[2].Pointer(), &owner) + _, err := owner.CopyOut(t, args[2].Pointer()) return 0, nil, err case linux.F_SETOWN_EX: var owner linux.FOwnerEx - _, err := t.CopyIn(args[2].Pointer(), &owner) + _, err := owner.CopyIn(t, args[2].Pointer()) if err != nil { return 0, nil, err } @@ -286,7 +286,7 @@ func posixLock(t *kernel.Task, args arch.SyscallArguments, file *vfs.FileDescrip // Copy in the lock request. flockAddr := args[2].Pointer() var flock linux.Flock - if _, err := t.CopyIn(flockAddr, &flock); err != nil { + if _, err := flock.CopyIn(t, flockAddr); err != nil { return err } diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 38778a388..2806c3f6f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/syserror" @@ -34,20 +35,20 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Handle ioctls that apply to all FDs. switch args[1].Int() { case linux.FIONCLEX: - t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: false, }) return 0, nil, nil case linux.FIOCLEX: - t.FDTable().SetFlagsVFS2(fd, kernel.FDFlags{ + t.FDTable().SetFlagsVFS2(t, fd, kernel.FDFlags{ CloseOnExec: true, }) return 0, nil, nil case linux.FIONBIO: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.StatusFlags() @@ -60,7 +61,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.FIOASYNC: var set int32 - if _, err := t.CopyIn(args[2].Pointer(), &set); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } flags := file.StatusFlags() @@ -82,12 +83,12 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall who = owner.PID } } - _, err := t.CopyOut(args[2].Pointer(), &who) + _, err := primitive.CopyInt32Out(t, args[2].Pointer(), who) return 0, nil, err case linux.FIOSETOWN, linux.SIOCSPGRP: var who int32 - if _, err := t.CopyIn(args[2].Pointer(), &who); err != nil { + if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &who); err != nil { return 0, nil, err } ownerType := int32(linux.F_OWNER_PID) diff --git a/pkg/sentry/syscalls/linux/vfs2/mmap.go b/pkg/sentry/syscalls/linux/vfs2/mmap.go index dc05c2994..9d9dbf775 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mmap.go +++ b/pkg/sentry/syscalls/linux/vfs2/mmap.go @@ -17,6 +17,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" @@ -85,6 +86,17 @@ func Mmap(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC if err := file.ConfigureMMap(t, &opts); err != nil { return 0, nil, err } + } else if shared { + // Back shared anonymous mappings with an anonymous tmpfs file. + opts.Offset = 0 + file, err := tmpfs.NewZeroFile(t, t.Credentials(), t.Kernel().ShmMount(), opts.Length) + if err != nil { + return 0, nil, err + } + defer file.DecRef(t) + if err := file.ConfigureMMap(t, &opts); err != nil { + return 0, nil, err + } } rv, err := t.MemoryManager().MMap(t, opts) diff --git a/pkg/sentry/syscalls/linux/vfs2/mount.go b/pkg/sentry/syscalls/linux/vfs2/mount.go index 4bd5c7ca2..769c9b92f 100644 --- a/pkg/sentry/syscalls/linux/vfs2/mount.go +++ b/pkg/sentry/syscalls/linux/vfs2/mount.go @@ -109,8 +109,8 @@ func Mount(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err } defer target.Release(t) - - return 0, nil, t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts) + _, err = t.Kernel().VFS().MountAt(t, creds, source, &target.pop, fsType, &opts) + return 0, nil, err } // Umount2 implements Linux syscall umount2(2). diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index 9b4848d9e..ee38fdca0 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -16,6 +16,7 @@ package vfs2 import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fsimpl/pipefs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -51,9 +52,9 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { if err != nil { return err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if _, file := t.FDTable().Remove(fd); file != nil { + if _, file := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } diff --git a/pkg/sentry/syscalls/linux/vfs2/poll.go b/pkg/sentry/syscalls/linux/vfs2/poll.go index 7b9d5e18a..c22e4ce54 100644 --- a/pkg/sentry/syscalls/linux/vfs2/poll.go +++ b/pkg/sentry/syscalls/linux/vfs2/poll.go @@ -165,7 +165,7 @@ func copyInPollFDs(t *kernel.Task, addr usermem.Addr, nfds uint) ([]linux.PollFD pfd := make([]linux.PollFD, nfds) if nfds > 0 { - if _, err := t.CopyIn(addr, &pfd); err != nil { + if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil { return nil, err } } @@ -192,7 +192,7 @@ func doPoll(t *kernel.Task, addr usermem.Addr, nfds uint, timeout time.Duration) // The poll entries are copied out regardless of whether // any are set or not. This aligns with the Linux behavior. if nfds > 0 && err == nil { - if _, err := t.CopyOut(addr, pfd); err != nil { + if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil { return remainingTimeout, 0, err } } @@ -205,7 +205,7 @@ func CopyInFDSet(t *kernel.Task, addr usermem.Addr, nBytes, nBitsInLastPartialBy set := make([]byte, nBytes) if addr != 0 { - if _, err := t.CopyIn(addr, &set); err != nil { + if _, err := t.CopyInBytes(addr, set); err != nil { return nil, err } // If we only use part of the last byte, mask out the extraneous bits. @@ -332,19 +332,19 @@ func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Add // Copy updated vectors back. if readFDs != 0 { - if _, err := t.CopyOut(readFDs, r); err != nil { + if _, err := t.CopyOutBytes(readFDs, r); err != nil { return 0, err } } if writeFDs != 0 { - if _, err := t.CopyOut(writeFDs, w); err != nil { + if _, err := t.CopyOutBytes(writeFDs, w); err != nil { return 0, err } } if exceptFDs != 0 { - if _, err := t.CopyOut(exceptFDs, e); err != nil { + if _, err := t.CopyOutBytes(exceptFDs, e); err != nil { return 0, err } } @@ -415,7 +415,7 @@ func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration nfds: nfds, timeout: remainingTimeout, }) - return 0, kernel.ERESTART_RESTARTBLOCK + return 0, syserror.ERESTART_RESTARTBLOCK } return n, err } @@ -462,7 +462,7 @@ func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall // Note that this means that if err is nil but copyErr is not, copyErr is // ignored. This is consistent with Linux. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } @@ -492,11 +492,17 @@ func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } +// +marshal +type sigSetWithSize struct { + sigsetAddr uint64 + sizeofSigset uint64 +} + // Pselect implements linux syscall pselect(2). func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { nfds := int(args[0].Int()) // select(2) uses an int. @@ -533,17 +539,11 @@ func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr) // See comment in Ppoll. if err == syserror.EINTR && copyErr == nil { - err = kernel.ERESTARTNOHAND + err = syserror.ERESTARTNOHAND } return n, nil, err } -// +marshal -type sigSetWithSize struct { - sigsetAddr uint64 - sizeofSigset uint64 -} - // copyTimespecInToDuration copies a Timespec from the untrusted app range, // validates it and converts it to a Duration. // diff --git a/pkg/sentry/syscalls/linux/vfs2/read_write.go b/pkg/sentry/syscalls/linux/vfs2/read_write.go index a905dae0a..b77b29dcc 100644 --- a/pkg/sentry/syscalls/linux/vfs2/read_write.go +++ b/pkg/sentry/syscalls/linux/vfs2/read_write.go @@ -62,7 +62,7 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC n, err := read(t, file, dst, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "read", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "read", file) } // Readv implements Linux syscall readv(2). @@ -87,7 +87,7 @@ func Readv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := read(t, file, dst, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "readv", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "readv", file) } func read(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { @@ -174,7 +174,7 @@ func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pread64", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pread64", file) } // Preadv implements Linux syscall preadv(2). @@ -205,7 +205,7 @@ func Preadv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := pread(t, file, dst, offset, vfs.ReadOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv", file) } // Preadv2 implements Linux syscall preadv2(2). @@ -251,7 +251,7 @@ func Preadv2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err = pread(t, file, dst, offset, opts) } t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "preadv2", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "preadv2", file) } func pread(t *kernel.Task, file *vfs.FileDescription, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { @@ -332,7 +332,7 @@ func Write(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall n, err := write(t, file, src, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "write", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "write", file) } // Writev implements Linux syscall writev(2). @@ -357,7 +357,7 @@ func Writev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal n, err := write(t, file, src, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "writev", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "writev", file) } func write(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { @@ -444,7 +444,7 @@ func Pwrite64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwrite64", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwrite64", file) } // Pwritev implements Linux syscall pwritev(2). @@ -475,7 +475,7 @@ func Pwritev(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca n, err := pwrite(t, file, src, offset, vfs.WriteOptions{}) t.IOUsage().AccountReadSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev", file) } // Pwritev2 implements Linux syscall pwritev2(2). @@ -521,7 +521,7 @@ func Pwritev2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc n, err = pwrite(t, file, src, offset, opts) } t.IOUsage().AccountWriteSyscall(n) - return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, kernel.ERESTARTSYS, "pwritev2", file) + return uintptr(n), nil, slinux.HandleIOErrorVFS2(t, n != 0, err, syserror.ERESTARTSYS, "pwritev2", file) } func pwrite(t *kernel.Task, file *vfs.FileDescription, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { diff --git a/pkg/sentry/syscalls/linux/vfs2/setstat.go b/pkg/sentry/syscalls/linux/vfs2/setstat.go index 5e6eb13ba..1ee37e5a8 100644 --- a/pkg/sentry/syscalls/linux/vfs2/setstat.go +++ b/pkg/sentry/syscalls/linux/vfs2/setstat.go @@ -346,7 +346,7 @@ func populateSetStatOptionsForUtimes(t *kernel.Task, timesAddr usermem.Addr, opt return nil } var times [2]linux.Timeval - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimevalSliceIn(t, timesAddr, times[:]); err != nil { return err } if times[0].Usec < 0 || times[0].Usec > 999999 || times[1].Usec < 0 || times[1].Usec > 999999 { @@ -410,7 +410,7 @@ func populateSetStatOptionsForUtimens(t *kernel.Task, timesAddr usermem.Addr, op return nil } var times [2]linux.Timespec - if _, err := t.CopyIn(timesAddr, ×); err != nil { + if _, err := linux.CopyTimespecSliceIn(t, timesAddr, times[:]); err != nil { return err } if times[0].Nsec != linux.UTIME_OMIT { diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 4a68c64f3..bfae6b7e9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -19,6 +19,8 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -30,8 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" - "gvisor.dev/gvisor/tools/go_marshal/primitive" ) // minListenBacklog is the minimum reasonable backlog for listening sockets. @@ -66,10 +66,10 @@ const flagsOffset = 48 const sizeOfInt32 = 4 // messageHeader64Len is the length of a MessageHeader64 struct. -var messageHeader64Len = uint64(binary.Size(MessageHeader64{})) +var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes()) // multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct. -var multipleMessageHeader64Len = uint64(binary.Size(multipleMessageHeader64{})) +var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes()) // baseRecvFlags are the flags that are accepted across recvmsg(2), // recvmmsg(2), and recvfrom(2). @@ -77,6 +77,8 @@ const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | // MessageHeader64 is the 64-bit representation of the msghdr struct used in // the recvmsg and sendmsg syscalls. +// +// +marshal type MessageHeader64 struct { // Name is the optional pointer to a network address buffer. Name uint64 @@ -105,30 +107,14 @@ type MessageHeader64 struct { // multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in // the recvmmsg and sendmmsg syscalls. +// +// +marshal type multipleMessageHeader64 struct { msgHdr MessageHeader64 msgLen uint32 _ int32 } -// CopyInMessageHeader64 copies a message header from user to kernel memory. -func CopyInMessageHeader64(t *kernel.Task, addr usermem.Addr, msg *MessageHeader64) error { - b := t.CopyScratchBuffer(52) - if _, err := t.CopyInBytes(addr, b); err != nil { - return err - } - - msg.Name = usermem.ByteOrder.Uint64(b[0:]) - msg.NameLen = usermem.ByteOrder.Uint32(b[8:]) - msg.Iov = usermem.ByteOrder.Uint64(b[16:]) - msg.IovLen = usermem.ByteOrder.Uint64(b[24:]) - msg.Control = usermem.ByteOrder.Uint64(b[32:]) - msg.ControlLen = usermem.ByteOrder.Uint64(b[40:]) - msg.Flags = int32(usermem.ByteOrder.Uint32(b[48:])) - - return nil -} - // CaptureAddress allocates memory for and copies a socket address structure // from the untrusted address space range. func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, error) { @@ -147,10 +133,10 @@ func CaptureAddress(t *kernel.Task, addr usermem.Addr, addrlen uint32) ([]byte, // writeAddress writes a sockaddr structure and its length to an output buffer // in the unstrusted address space range. If the address is bigger than the // buffer, it is truncated. -func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { +func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr usermem.Addr, addrLenPtr usermem.Addr) error { // Get the buffer length. var bufLen uint32 - if _, err := t.CopyIn(addrLenPtr, &bufLen); err != nil { + if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil { return err } @@ -159,7 +145,7 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Write the length unconditionally. - if _, err := t.CopyOut(addrLenPtr, addrLen); err != nil { + if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil { return err } @@ -172,7 +158,8 @@ func writeAddress(t *kernel.Task, addr interface{}, addrLen uint32, addrPtr user } // Copy as much of the address as will fit in the buffer. - encodedAddr := binary.Marshal(nil, usermem.ByteOrder, addr) + encodedAddr := t.CopyScratchBuffer(addr.SizeBytes()) + addr.MarshalUnsafe(encodedAddr) if bufLen > uint32(len(encodedAddr)) { bufLen = uint32(len(encodedAddr)) } @@ -250,9 +237,9 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, err } - if _, err := t.CopyOut(addr, fds); err != nil { + if _, err := primitive.CopyInt32SliceOut(t, addr, fds); err != nil { for _, fd := range fds { - if _, file := t.FDTable().Remove(fd); file != nil { + if _, file := t.FDTable().Remove(t, fd); file != nil { file.DecRef(t) } } @@ -288,7 +275,7 @@ func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } blocking := (file.StatusFlags() & linux.SOCK_NONBLOCK) == 0 - return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS) } // accept is the implementation of the accept syscall. It is called by accept @@ -319,7 +306,7 @@ func accept(t *kernel.Task, fd int32, addr usermem.Addr, addrLen usermem.Addr, f peerRequested := addrLen != 0 nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } if peerRequested { // NOTE(magi): Linux does not give you an error if it can't @@ -459,8 +446,8 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Read the length. Reject negative values. - optLen := int32(0) - if _, err := t.CopyIn(optLenAddr, &optLen); err != nil { + var optLen int32 + if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil { return 0, nil, err } if optLen < 0 { @@ -474,7 +461,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } vLen := int32(binary.Size(v)) - if _, err := t.CopyOut(optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { return 0, nil, err } @@ -736,7 +723,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -751,7 +738,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) { // Capture the message header and io vectors. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -774,7 +761,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) if err != nil { - return 0, syserror.ConvertIntr(err.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS) } if !cms.Unix.Empty() { mflags |= linux.MSG_CTRUNC @@ -783,7 +770,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla if int(msg.Flags) != mflags { // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } } @@ -796,7 +783,7 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } defer cms.Release(t) @@ -820,17 +807,17 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla } // Copy the control data to the caller. - if _, err := t.CopyOut(msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { + if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil { return 0, err } if len(controlData) > 0 { - if _, err := t.CopyOut(usermem.Addr(msg.Control), controlData); err != nil { + if _, err := t.CopyOutBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } // Copy out the flags to the caller. - if _, err := t.CopyOut(msgPtr+flagsOffset, int32(mflags)); err != nil { + if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil { return 0, err } @@ -885,7 +872,7 @@ func recvFrom(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flag n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0) cm.Release(t) if e != nil { - return 0, syserror.ConvertIntr(e.ToError(), kernel.ERESTARTSYS) + return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS) } // Copy the address to the caller. @@ -999,7 +986,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if !ok { return 0, nil, syserror.EFAULT } - if _, err = t.CopyOut(lp, uint32(n)); err != nil { + if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil { break } count++ @@ -1014,7 +1001,7 @@ func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescription, msgPtr usermem.Addr, flags int32) (uintptr, error) { // Capture the message header. var msg MessageHeader64 - if err := CopyInMessageHeader64(t, msgPtr, &msg); err != nil { + if _, err := msg.CopyIn(t, msgPtr); err != nil { return 0, err } @@ -1025,7 +1012,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio return 0, syserror.ENOBUFS } controlData = make([]byte, msg.ControlLen) - if _, err := t.CopyIn(usermem.Addr(msg.Control), &controlData); err != nil { + if _, err := t.CopyInBytes(usermem.Addr(msg.Control), controlData); err != nil { return 0, err } } @@ -1067,7 +1054,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages) - err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendmsg", file) + err = slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file) if err != nil { controlMessages.Release(t) } @@ -1127,7 +1114,7 @@ func sendTo(t *kernel.Task, fd int32, bufPtr usermem.Addr, bufLen uint64, flags // Call the syscall implementation. n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)}) - return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), kernel.ERESTARTSYS, "sendto", file) + return uintptr(n), slinux.HandleIOErrorVFS2(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file) } // SendTo implements the linux syscall sendto(2). diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 75bfa2c79..f55d74cd2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -18,6 +18,8 @@ import ( "io" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" @@ -88,7 +90,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if inFile.Options().DenyPRead { return 0, nil, syserror.EINVAL } - if _, err := t.CopyIn(inOffsetPtr, &inOffset); err != nil { + if _, err := primitive.CopyInt64In(t, inOffsetPtr, &inOffset); err != nil { return 0, nil, err } if inOffset < 0 { @@ -103,7 +105,7 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal if outFile.Options().DenyPWrite { return 0, nil, syserror.EINVAL } - if _, err := t.CopyIn(outOffsetPtr, &outOffset); err != nil { + if _, err := primitive.CopyInt64In(t, outOffsetPtr, &outOffset); err != nil { return 0, nil, err } if outOffset < 0 { @@ -131,23 +133,24 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal case inIsPipe && outIsPipe: n, err = pipe.Splice(t, outPipeFD, inPipeFD, count) case inIsPipe: + n, err = inPipeFD.SpliceToNonPipe(t, outFile, outOffset, count) if outOffset != -1 { - n, err = outFile.PWrite(t, inPipeFD.IOSequence(count), outOffset, vfs.WriteOptions{}) outOffset += n - } else { - n, err = outFile.Write(t, inPipeFD.IOSequence(count), vfs.WriteOptions{}) } case outIsPipe: + n, err = outPipeFD.SpliceFromNonPipe(t, inFile, inOffset, count) if inOffset != -1 { - n, err = inFile.PRead(t, outPipeFD.IOSequence(count), inOffset, vfs.ReadOptions{}) inOffset += n - } else { - n, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) } default: - panic("not possible") + panic("at least one end of splice must be a pipe") } + if n == 0 && err == io.EOF { + // We reached the end of the file. Eat the error and exit the loop. + err = nil + break + } if n != 0 || err != syserror.ErrWouldBlock || nonBlock { break } @@ -158,12 +161,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal // Copy updated offsets out. if inOffsetPtr != 0 { - if _, err := t.CopyOut(inOffsetPtr, &inOffset); err != nil { + if _, err := primitive.CopyInt64Out(t, inOffsetPtr, inOffset); err != nil { return 0, nil, err } } if outOffsetPtr != 0 { - if _, err := t.CopyOut(outOffsetPtr, &outOffset); err != nil { + if _, err := primitive.CopyInt64Out(t, outOffsetPtr, outOffset); err != nil { return 0, nil, err } } @@ -301,9 +304,12 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if inFile.Options().DenyPRead { return 0, nil, syserror.ESPIPE } - if _, err := t.CopyIn(offsetAddr, &offset); err != nil { + var offsetP primitive.Int64 + if _, err := offsetP.CopyIn(t, offsetAddr); err != nil { return 0, nil, err } + offset = int64(offsetP) + if offset < 0 { return 0, nil, syserror.EINVAL } @@ -341,17 +347,15 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if outIsPipe { for n < count { var spliceN int64 - if offset != -1 { - spliceN, err = inFile.PRead(t, outPipeFD.IOSequence(count), offset, vfs.ReadOptions{}) - offset += spliceN - } else { - spliceN, err = inFile.Read(t, outPipeFD.IOSequence(count), vfs.ReadOptions{}) - } + spliceN, err = outPipeFD.SpliceFromNonPipe(t, inFile, offset, count) if spliceN == 0 && err == io.EOF { // We reached the end of the file. Eat the error and exit the loop. err = nil break } + if offset != -1 { + offset += spliceN + } n += spliceN if err == syserror.ErrWouldBlock && !nonBlock { err = dw.waitForBoth(t) @@ -371,19 +375,18 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } else { readN, err = inFile.Read(t, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) } - if readN == 0 && err == io.EOF { - // We reached the end of the file. Eat the error and exit the loop. - err = nil + if readN == 0 && err != nil { + if err == io.EOF { + // We reached the end of the file. Eat the error before exiting the loop. + err = nil + } break } n += readN - if err != nil { - break - } // Write all of the bytes that we read. This may need // multiple write calls to complete. - wbuf := buf[:n] + wbuf := buf[:readN] for len(wbuf) > 0 { var writeN int64 writeN, err = outFile.Write(t, usermem.BytesIOSequence(wbuf), vfs.WriteOptions{}) @@ -392,12 +395,21 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc err = dw.waitForOut(t) } if err != nil { - // We didn't complete the write. Only - // report the bytes that were actually - // written, and rewind the offset. + // We didn't complete the write. Only report the bytes that were actually + // written, and rewind offsets as needed. notWritten := int64(len(wbuf)) n -= notWritten - if offset != -1 { + if offset == -1 { + // We modified the offset of the input file itself during the read + // operation. Rewind it. + if _, seekErr := inFile.Seek(t, -notWritten, linux.SEEK_CUR); seekErr != nil { + // Log the error but don't return it, since the write has already + // completed successfully. + log.Warningf("failed to roll back input file offset: %v", seekErr) + } + } else { + // The sendfile call was provided an offset parameter that should be + // adjusted to reflect the number of bytes sent. Rewind it. offset -= notWritten } break @@ -414,7 +426,8 @@ func Sendfile(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc if offsetAddr != 0 { // Copy out the new offset. - if _, err := t.CopyOut(offsetAddr, offset); err != nil { + offsetP := primitive.Uint64(offset) + if _, err := offsetP.CopyOut(t, offsetAddr); err != nil { return 0, nil, err } } diff --git a/pkg/sentry/syscalls/linux/vfs2/sync.go b/pkg/sentry/syscalls/linux/vfs2/sync.go index a6491ac37..6e9b599e2 100644 --- a/pkg/sentry/syscalls/linux/vfs2/sync.go +++ b/pkg/sentry/syscalls/linux/vfs2/sync.go @@ -108,7 +108,7 @@ func SyncFileRange(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel if flags&linux.SYNC_FILE_RANGE_WAIT_AFTER != 0 { if err := file.Sync(t); err != nil { - return 0, nil, syserror.ConvertIntr(err, kernel.ERESTARTSYS) + return 0, nil, syserror.ConvertIntr(err, syserror.ERESTARTSYS) } } return 0, nil, nil diff --git a/pkg/sentry/syscalls/linux/vfs2/timerfd.go b/pkg/sentry/syscalls/linux/vfs2/timerfd.go index 7a26890ef..250870c03 100644 --- a/pkg/sentry/syscalls/linux/vfs2/timerfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/timerfd.go @@ -87,7 +87,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne } var newVal linux.Itimerspec - if _, err := t.CopyIn(newValAddr, &newVal); err != nil { + if _, err := newVal.CopyIn(t, newValAddr); err != nil { return 0, nil, err } newS, err := ktime.SettingFromItimerspec(newVal, flags&linux.TFD_TIMER_ABSTIME != 0, tfd.Clock()) @@ -97,7 +97,7 @@ func TimerfdSettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, oldS := tfd.SetTime(newS) if oldValAddr != 0 { oldVal := ktime.ItimerspecFromSetting(tm, oldS) - if _, err := t.CopyOut(oldValAddr, &oldVal); err != nil { + if _, err := oldVal.CopyOut(t, oldValAddr); err != nil { return 0, nil, err } } @@ -122,6 +122,6 @@ func TimerfdGettime(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kerne tm, s := tfd.GetTime() curVal := ktime.ItimerspecFromSetting(tm, s) - _, err := t.CopyOut(curValAddr, &curVal) + _, err := curVal.CopyOut(t, curValAddr) return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index c576d9475..0df3bd449 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -93,16 +93,16 @@ func Override() { s.Table[165] = syscalls.Supported("mount", Mount) s.Table[166] = syscalls.Supported("umount2", Umount2) s.Table[187] = syscalls.Supported("readahead", Readahead) - s.Table[188] = syscalls.Supported("setxattr", Setxattr) + s.Table[188] = syscalls.Supported("setxattr", SetXattr) s.Table[189] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[190] = syscalls.Supported("fsetxattr", Fsetxattr) - s.Table[191] = syscalls.Supported("getxattr", Getxattr) + s.Table[191] = syscalls.Supported("getxattr", GetXattr) s.Table[192] = syscalls.Supported("lgetxattr", Lgetxattr) s.Table[193] = syscalls.Supported("fgetxattr", Fgetxattr) - s.Table[194] = syscalls.Supported("listxattr", Listxattr) + s.Table[194] = syscalls.Supported("listxattr", ListXattr) s.Table[195] = syscalls.Supported("llistxattr", Llistxattr) s.Table[196] = syscalls.Supported("flistxattr", Flistxattr) - s.Table[197] = syscalls.Supported("removexattr", Removexattr) + s.Table[197] = syscalls.Supported("removexattr", RemoveXattr) s.Table[198] = syscalls.Supported("lremovexattr", Lremovexattr) s.Table[199] = syscalls.Supported("fremovexattr", Fremovexattr) s.Table[209] = syscalls.PartiallySupported("io_submit", IoSubmit, "Generally supported with exceptions. User ring optimizations are not implemented.", []string{"gvisor.dev/issue/204"}) @@ -163,16 +163,16 @@ func Override() { // Override ARM64. s = linux.ARM64 - s.Table[5] = syscalls.Supported("setxattr", Setxattr) + s.Table[5] = syscalls.Supported("setxattr", SetXattr) s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr) s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr) - s.Table[8] = syscalls.Supported("getxattr", Getxattr) + s.Table[8] = syscalls.Supported("getxattr", GetXattr) s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr) s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr) - s.Table[11] = syscalls.Supported("listxattr", Listxattr) + s.Table[11] = syscalls.Supported("listxattr", ListXattr) s.Table[12] = syscalls.Supported("llistxattr", Llistxattr) s.Table[13] = syscalls.Supported("flistxattr", Flistxattr) - s.Table[14] = syscalls.Supported("removexattr", Removexattr) + s.Table[14] = syscalls.Supported("removexattr", RemoveXattr) s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr) s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr) s.Table[17] = syscalls.Supported("getcwd", Getcwd) diff --git a/pkg/sentry/syscalls/linux/vfs2/xattr.go b/pkg/sentry/syscalls/linux/vfs2/xattr.go index ef99246ed..e05723ef9 100644 --- a/pkg/sentry/syscalls/linux/vfs2/xattr.go +++ b/pkg/sentry/syscalls/linux/vfs2/xattr.go @@ -26,8 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Listxattr implements Linux syscall listxattr(2). -func Listxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// ListXattr implements Linux syscall listxattr(2). +func ListXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return listxattr(t, args, followFinalSymlink) } @@ -51,7 +51,7 @@ func listxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSyml } defer tpop.Release(t) - names, err := t.Kernel().VFS().ListxattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) + names, err := t.Kernel().VFS().ListXattrAt(t, t.Credentials(), &tpop.pop, uint64(size)) if err != nil { return 0, nil, err } @@ -74,7 +74,7 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } defer file.DecRef(t) - names, err := file.Listxattr(t, uint64(size)) + names, err := file.ListXattr(t, uint64(size)) if err != nil { return 0, nil, err } @@ -85,8 +85,8 @@ func Flistxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return uintptr(n), nil, nil } -// Getxattr implements Linux syscall getxattr(2). -func Getxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// GetXattr implements Linux syscall getxattr(2). +func GetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return getxattr(t, args, followFinalSymlink) } @@ -116,7 +116,7 @@ func getxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli return 0, nil, err } - value, err := t.Kernel().VFS().GetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetxattrOptions{ + value, err := t.Kernel().VFS().GetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.GetXattrOptions{ Name: name, Size: uint64(size), }) @@ -148,7 +148,7 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - value, err := file.Getxattr(t, &vfs.GetxattrOptions{Name: name, Size: uint64(size)}) + value, err := file.GetXattr(t, &vfs.GetXattrOptions{Name: name, Size: uint64(size)}) if err != nil { return 0, nil, err } @@ -159,8 +159,8 @@ func Fgetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return uintptr(n), nil, nil } -// Setxattr implements Linux syscall setxattr(2). -func Setxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// SetXattr implements Linux syscall setxattr(2). +func SetXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, setxattr(t, args, followFinalSymlink) } @@ -199,7 +199,7 @@ func setxattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSymli return err } - return t.Kernel().VFS().SetxattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetxattrOptions{ + return t.Kernel().VFS().SetXattrAt(t, t.Credentials(), &tpop.pop, &vfs.SetXattrOptions{ Name: name, Value: value, Flags: uint32(flags), @@ -233,15 +233,15 @@ func Fsetxattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys return 0, nil, err } - return 0, nil, file.Setxattr(t, &vfs.SetxattrOptions{ + return 0, nil, file.SetXattr(t, &vfs.SetXattrOptions{ Name: name, Value: value, Flags: uint32(flags), }) } -// Removexattr implements Linux syscall removexattr(2). -func Removexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { +// RemoveXattr implements Linux syscall removexattr(2). +func RemoveXattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { return 0, nil, removexattr(t, args, followFinalSymlink) } @@ -269,7 +269,7 @@ func removexattr(t *kernel.Task, args arch.SyscallArguments, shouldFollowFinalSy return err } - return t.Kernel().VFS().RemovexattrAt(t, t.Credentials(), &tpop.pop, name) + return t.Kernel().VFS().RemoveXattrAt(t, t.Credentials(), &tpop.pop, name) } // Fremovexattr implements Linux syscall fremovexattr(2). @@ -288,7 +288,7 @@ func Fremovexattr(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel. return 0, nil, err } - return 0, nil, file.Removexattr(t, name) + return 0, nil, file.RemoveXattr(t, name) } func copyInXattrName(t *kernel.Task, nameAddr usermem.Addr) (string, error) { diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 642769e7c..8093ca55c 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -27,6 +27,39 @@ go_template_instance( }, ) +go_template_instance( + name = "file_description_refs", + out = "file_description_refs.go", + package = "vfs", + prefix = "FileDescription", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "FileDescription", + }, +) + +go_template_instance( + name = "mount_namespace_refs", + out = "mount_namespace_refs.go", + package = "vfs", + prefix = "MountNamespace", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "MountNamespace", + }, +) + +go_template_instance( + name = "filesystem_refs", + out = "filesystem_refs.go", + package = "vfs", + prefix = "Filesystem", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "Filesystem", + }, +) + go_library( name = "vfs", srcs = [ @@ -40,12 +73,15 @@ go_library( "event_list.go", "file_description.go", "file_description_impl_util.go", + "file_description_refs.go", "filesystem.go", "filesystem_impl_util.go", + "filesystem_refs.go", "filesystem_type.go", "inotify.go", "lock.go", "mount.go", + "mount_namespace_refs.go", "mount_unsafe.go", "options.go", "pathname.go", @@ -63,6 +99,7 @@ go_library( "//pkg/fspath", "//pkg/gohacks", "//pkg/log", + "//pkg/refs", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/vfs/README.md b/pkg/sentry/vfs/README.md index 4b9faf2ea..5aad31b78 100644 --- a/pkg/sentry/vfs/README.md +++ b/pkg/sentry/vfs/README.md @@ -184,12 +184,3 @@ This construction, which is essentially a type-safe analogue to Linux's - File locking - `O_ASYNC` - -- Reference counts in the `vfs` package do not use the `refs` package since - `refs.AtomicRefCount` adds 64 bytes of overhead to each 8-byte reference - count, resulting in considerable cache bloat. 24 bytes of this overhead is - for weak reference support, which have poor performance and will not be used - by VFS2. The remaining 40 bytes is to store a descriptive string and stack - trace for reference leak checking; we can support reference leak checking - without incurring this space overhead by including the applicable - information directly in finalizers for applicable types. diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 5a0e3e6b5..9c4db3047 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -245,32 +245,32 @@ func (fs *anonFilesystem) BoundEndpointAt(ctx context.Context, rp *ResolvingPath return nil, syserror.ECONNREFUSED } -// ListxattrAt implements FilesystemImpl.ListxattrAt. -func (fs *anonFilesystem) ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { +// ListXattrAt implements FilesystemImpl.ListXattrAt. +func (fs *anonFilesystem) ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) { if !rp.Done() { return nil, syserror.ENOTDIR } return nil, nil } -// GetxattrAt implements FilesystemImpl.GetxattrAt. -func (fs *anonFilesystem) GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) { +// GetXattrAt implements FilesystemImpl.GetXattrAt. +func (fs *anonFilesystem) GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error) { if !rp.Done() { return "", syserror.ENOTDIR } return "", syserror.ENOTSUP } -// SetxattrAt implements FilesystemImpl.SetxattrAt. -func (fs *anonFilesystem) SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error { +// SetXattrAt implements FilesystemImpl.SetXattrAt. +func (fs *anonFilesystem) SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error { if !rp.Done() { return syserror.ENOTDIR } return syserror.EPERM } -// RemovexattrAt implements FilesystemImpl.RemovexattrAt. -func (fs *anonFilesystem) RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error { +// RemoveXattrAt implements FilesystemImpl.RemoveXattrAt. +func (fs *anonFilesystem) RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error { if !rp.Done() { return syserror.ENOTDIR } diff --git a/pkg/sentry/vfs/context.go b/pkg/sentry/vfs/context.go index c9e724fef..97018651f 100644 --- a/pkg/sentry/vfs/context.go +++ b/pkg/sentry/vfs/context.go @@ -40,6 +40,30 @@ func MountNamespaceFromContext(ctx context.Context) *MountNamespace { return nil } +type mountNamespaceContext struct { + context.Context + mntns *MountNamespace +} + +// Value implements Context.Value. +func (mc mountNamespaceContext) Value(key interface{}) interface{} { + switch key { + case CtxMountNamespace: + mc.mntns.IncRef() + return mc.mntns + default: + return mc.Context.Value(key) + } +} + +// WithMountNamespace returns a copy of ctx with the given MountNamespace. +func WithMountNamespace(ctx context.Context, mntns *MountNamespace) context.Context { + return &mountNamespaceContext{ + Context: ctx, + mntns: mntns, + } +} + // RootFromContext returns the VFS root used by ctx. It takes a reference on // the returned VirtualDentry. If ctx does not have a specific VFS root, // RootFromContext returns a zero-value VirtualDentry. diff --git a/pkg/sentry/vfs/dentry.go b/pkg/sentry/vfs/dentry.go index bc7ea93ea..a69a5b2f1 100644 --- a/pkg/sentry/vfs/dentry.go +++ b/pkg/sentry/vfs/dentry.go @@ -242,8 +242,9 @@ func (vfs *VirtualFilesystem) InvalidateDentry(ctx context.Context, d *Dentry) { // caller must call AbortRenameDentry, CommitRenameReplaceDentry, or // CommitRenameExchangeDentry depending on the rename's outcome. // -// Preconditions: If to is not nil, it must be a child Dentry from the same -// Filesystem. from != to. +// Preconditions: +// * If to is not nil, it must be a child Dentry from the same Filesystem. +// * from != to. func (vfs *VirtualFilesystem) PrepareRenameDentry(mntns *MountNamespace, from, to *Dentry) error { vfs.mountMu.Lock() if mntns.mountpoints[from] != 0 { diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 1b5af9f73..754e76aec 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -331,11 +331,9 @@ func (ep *EpollInstance) removeLocked(epi *epollInterest) { ep.mu.Unlock() } -// ReadEvents reads up to len(events) ready events into events and returns the -// number of events read. -// -// Preconditions: len(events) != 0. -func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int { +// ReadEvents appends up to maxReady events to events and returns the updated +// slice of events. +func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent, maxEvents int) []linux.EpollEvent { i := 0 // Hot path: avoid defer. ep.mu.Lock() @@ -368,16 +366,16 @@ func (ep *EpollInstance) ReadEvents(events []linux.EpollEvent) int { requeue.PushBack(epi) } // Report ievents. - events[i] = linux.EpollEvent{ + events = append(events, linux.EpollEvent{ Events: ievents.ToLinux(), Data: epi.userData, - } + }) i++ - if i == len(events) { + if i == maxEvents { break } } ep.ready.PushBackList(&requeue) ep.mu.Unlock() - return i + return events } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index dcafffe57..73bb36d3e 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -38,9 +38,7 @@ import ( // // FileDescription is analogous to Linux's struct file. type FileDescription struct { - // refs is the reference count. refs is accessed using atomic memory - // operations. - refs int64 + FileDescriptionRefs // flagsMu protects statusFlags and asyncHandler below. flagsMu sync.Mutex @@ -103,7 +101,7 @@ type FileDescriptionOptions struct { // If UseDentryMetadata is true, calls to FileDescription methods that // interact with file and filesystem metadata (Stat, SetStat, StatFS, - // Listxattr, Getxattr, Setxattr, Removexattr) are implemented by calling + // ListXattr, GetXattr, SetXattr, RemoveXattr) are implemented by calling // the corresponding FilesystemImpl methods instead of the corresponding // FileDescriptionImpl methods. // @@ -131,7 +129,7 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou } } - fd.refs = 1 + fd.EnableLeakCheck() // Remove "file creation flags" to mirror the behavior from file.f_flags in // fs/open.c:do_dentry_open. @@ -149,30 +147,9 @@ func (fd *FileDescription) Init(impl FileDescriptionImpl, flags uint32, mnt *Mou return nil } -// IncRef increments fd's reference count. -func (fd *FileDescription) IncRef() { - atomic.AddInt64(&fd.refs, 1) -} - -// TryIncRef increments fd's reference count and returns true. If fd's -// reference count is already zero, TryIncRef does nothing and returns false. -// -// TryIncRef does not require that a reference is held on fd. -func (fd *FileDescription) TryIncRef() bool { - for { - refs := atomic.LoadInt64(&fd.refs) - if refs <= 0 { - return false - } - if atomic.CompareAndSwapInt64(&fd.refs, refs, refs+1) { - return true - } - } -} - // DecRef decrements fd's reference count. func (fd *FileDescription) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&fd.refs, -1); refs == 0 { + fd.FileDescriptionRefs.DecRef(func() { // Unregister fd from all epoll instances. fd.epollMu.Lock() epolls := fd.epolls @@ -208,15 +185,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.asyncHandler = nil fd.flagsMu.Unlock() - } else if refs < 0 { - panic("FileDescription.DecRef() called without holding a reference") - } -} - -// Refs returns the current number of references. The returned count -// is inherently racy and is unsafe to use without external synchronization. -func (fd *FileDescription) Refs() int64 { - return atomic.LoadInt64(&fd.refs) + }) } // Mount returns the mount on which fd was opened. It does not take a reference @@ -357,6 +326,9 @@ type FileDescriptionImpl interface { // Allocate grows the file to offset + length bytes. // Only mode == 0 is supported currently. // + // Allocate should return EISDIR on directories, ESPIPE on pipes, and ENODEV on + // other files where it is not supported. + // // Preconditions: The FileDescription was opened for writing. Allocate(ctx context.Context, mode, offset, length uint64) error @@ -371,8 +343,9 @@ type FileDescriptionImpl interface { // // - If opts.Flags specifies unsupported options, PRead returns EOPNOTSUPP. // - // Preconditions: The FileDescription was opened for reading. - // FileDescriptionOptions.DenyPRead == false. + // Preconditions: + // * The FileDescription was opened for reading. + // * FileDescriptionOptions.DenyPRead == false. PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts ReadOptions) (int64, error) // Read is similar to PRead, but does not specify an offset. @@ -403,8 +376,9 @@ type FileDescriptionImpl interface { // - If opts.Flags specifies unsupported options, PWrite returns // EOPNOTSUPP. // - // Preconditions: The FileDescription was opened for writing. - // FileDescriptionOptions.DenyPWrite == false. + // Preconditions: + // * The FileDescription was opened for writing. + // * FileDescriptionOptions.DenyPWrite == false. PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts WriteOptions) (int64, error) // Write is similar to PWrite, but does not specify an offset, which is @@ -449,19 +423,19 @@ type FileDescriptionImpl interface { // Ioctl implements the ioctl(2) syscall. Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) - // Listxattr returns all extended attribute names for the file. - Listxattr(ctx context.Context, size uint64) ([]string, error) + // ListXattr returns all extended attribute names for the file. + ListXattr(ctx context.Context, size uint64) ([]string, error) - // Getxattr returns the value associated with the given extended attribute + // GetXattr returns the value associated with the given extended attribute // for the file. - Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) + GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) - // Setxattr changes the value associated with the given extended attribute + // SetXattr changes the value associated with the given extended attribute // for the file. - Setxattr(ctx context.Context, opts SetxattrOptions) error + SetXattr(ctx context.Context, opts SetXattrOptions) error - // Removexattr removes the given extended attribute from the file. - Removexattr(ctx context.Context, name string) error + // RemoveXattr removes the given extended attribute from the file. + RemoveXattr(ctx context.Context, name string) error // LockBSD tries to acquire a BSD-style advisory file lock. LockBSD(ctx context.Context, uid lock.UniqueID, t lock.LockType, block lock.Blocker) error @@ -664,25 +638,25 @@ func (fd *FileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. return fd.impl.Ioctl(ctx, uio, args) } -// Listxattr returns all extended attribute names for the file represented by +// ListXattr returns all extended attribute names for the file represented by // fd. // // If the size of the list (including a NUL terminating byte after every entry) // would exceed size, ERANGE may be returned. Note that implementations // are free to ignore size entirely and return without error). In all cases, // if size is 0, the list should be returned without error, regardless of size. -func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string, error) { +func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - names, err := fd.vd.mount.fs.impl.ListxattrAt(ctx, rp, size) + names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size) vfsObj.putResolvingPath(ctx, rp) return names, err } - names, err := fd.impl.Listxattr(ctx, size) + names, err := fd.impl.ListXattr(ctx, size) if err == syserror.ENOTSUP { // Linux doesn't actually return ENOTSUP in this case; instead, // fs/xattr.c:vfs_listxattr() falls back to allowing the security @@ -693,57 +667,57 @@ func (fd *FileDescription) Listxattr(ctx context.Context, size uint64) ([]string return names, err } -// Getxattr returns the value associated with the given extended attribute for +// GetXattr returns the value associated with the given extended attribute for // the file represented by fd. // // If the size of the return value exceeds opts.Size, ERANGE may be returned // (note that implementations are free to ignore opts.Size entirely and return // without error). In all cases, if opts.Size is 0, the value should be // returned without error, regardless of size. -func (fd *FileDescription) Getxattr(ctx context.Context, opts *GetxattrOptions) (string, error) { +func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) (string, error) { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - val, err := fd.vd.mount.fs.impl.GetxattrAt(ctx, rp, *opts) + val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(ctx, rp) return val, err } - return fd.impl.Getxattr(ctx, *opts) + return fd.impl.GetXattr(ctx, *opts) } -// Setxattr changes the value associated with the given extended attribute for +// SetXattr changes the value associated with the given extended attribute for // the file represented by fd. -func (fd *FileDescription) Setxattr(ctx context.Context, opts *SetxattrOptions) error { +func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts) vfsObj.putResolvingPath(ctx, rp) return err } - return fd.impl.Setxattr(ctx, *opts) + return fd.impl.SetXattr(ctx, *opts) } -// Removexattr removes the given extended attribute from the file represented +// RemoveXattr removes the given extended attribute from the file represented // by fd. -func (fd *FileDescription) Removexattr(ctx context.Context, name string) error { +func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error { if fd.opts.UseDentryMetadata { vfsObj := fd.vd.mount.vfs rp := vfsObj.getResolvingPath(auth.CredentialsFromContext(ctx), &PathOperation{ Root: fd.vd, Start: fd.vd, }) - err := fd.vd.mount.fs.impl.RemovexattrAt(ctx, rp, name) + err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name) vfsObj.putResolvingPath(ctx, rp) return err } - return fd.impl.Removexattr(ctx, name) + return fd.impl.RemoveXattr(ctx, name) } // SyncFS instructs the filesystem containing fd to execute the semantics of @@ -845,3 +819,31 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn } return fd.asyncHandler } + +// FileReadWriteSeeker is a helper struct to pass a FileDescription as +// io.Reader/io.Writer/io.ReadSeeker/etc. +type FileReadWriteSeeker struct { + FD *FileDescription + Ctx context.Context + ROpts ReadOptions + WOpts WriteOptions +} + +// Read implements io.ReadWriteSeeker.Read. +func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { + dst := usermem.BytesIOSequence(p) + ret, err := f.FD.Read(f.Ctx, dst, f.ROpts) + return int(ret), err +} + +// Seek implements io.ReadWriteSeeker.Seek. +func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { + return f.FD.Seek(f.Ctx, offset, int32(whence)) +} + +// Write implements io.ReadWriteSeeker.Write. +func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { + buf := usermem.BytesIOSequence(p) + ret, err := f.FD.Write(f.Ctx, buf, f.WOpts) + return int(ret), err +} diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 6b8b4ad49..78da16bac 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -57,7 +57,11 @@ func (FileDescriptionDefaultImpl) StatFS(ctx context.Context) (linux.Statfs, err } // Allocate implements FileDescriptionImpl.Allocate analogously to -// fallocate called on regular file, directory or FIFO in Linux. +// fallocate called on an invalid type of file in Linux. +// +// Note that directories can rely on this implementation even though they +// should technically return EISDIR. Allocate should never be called for a +// directory, because it requires a writable fd. func (FileDescriptionDefaultImpl) Allocate(ctx context.Context, mode, offset, length uint64) error { return syserror.ENODEV } @@ -134,28 +138,28 @@ func (FileDescriptionDefaultImpl) Ioctl(ctx context.Context, uio usermem.IO, arg return 0, syserror.ENOTTY } -// Listxattr implements FileDescriptionImpl.Listxattr analogously to +// ListXattr implements FileDescriptionImpl.ListXattr analogously to // inode_operations::listxattr == NULL in Linux. -func (FileDescriptionDefaultImpl) Listxattr(ctx context.Context, size uint64) ([]string, error) { - // This isn't exactly accurate; see FileDescription.Listxattr. +func (FileDescriptionDefaultImpl) ListXattr(ctx context.Context, size uint64) ([]string, error) { + // This isn't exactly accurate; see FileDescription.ListXattr. return nil, syserror.ENOTSUP } -// Getxattr implements FileDescriptionImpl.Getxattr analogously to +// GetXattr implements FileDescriptionImpl.GetXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Getxattr(ctx context.Context, opts GetxattrOptions) (string, error) { +func (FileDescriptionDefaultImpl) GetXattr(ctx context.Context, opts GetXattrOptions) (string, error) { return "", syserror.ENOTSUP } -// Setxattr implements FileDescriptionImpl.Setxattr analogously to +// SetXattr implements FileDescriptionImpl.SetXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Setxattr(ctx context.Context, opts SetxattrOptions) error { +func (FileDescriptionDefaultImpl) SetXattr(ctx context.Context, opts SetXattrOptions) error { return syserror.ENOTSUP } -// Removexattr implements FileDescriptionImpl.Removexattr analogously to +// RemoveXattr implements FileDescriptionImpl.RemoveXattr analogously to // inode::i_opflags & IOP_XATTR == 0 in Linux. -func (FileDescriptionDefaultImpl) Removexattr(ctx context.Context, name string) error { +func (FileDescriptionDefaultImpl) RemoveXattr(ctx context.Context, name string) error { return syserror.ENOTSUP } diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index df3758fd1..7dae4e7e8 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -15,8 +15,6 @@ package vfs import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" @@ -34,9 +32,7 @@ import ( // // +stateify savable type Filesystem struct { - // refs is the reference count. refs is accessed using atomic memory - // operations. - refs int64 + FilesystemRefs // vfs is the VirtualFilesystem that uses this Filesystem. vfs is // immutable. @@ -52,7 +48,7 @@ type Filesystem struct { // Init must be called before first use of fs. func (fs *Filesystem) Init(vfsObj *VirtualFilesystem, fsType FilesystemType, impl FilesystemImpl) { - fs.refs = 1 + fs.EnableLeakCheck() fs.vfs = vfsObj fs.fsType = fsType fs.impl = impl @@ -76,39 +72,14 @@ func (fs *Filesystem) Impl() FilesystemImpl { return fs.impl } -// IncRef increments fs' reference count. -func (fs *Filesystem) IncRef() { - if atomic.AddInt64(&fs.refs, 1) <= 1 { - panic("Filesystem.IncRef() called without holding a reference") - } -} - -// TryIncRef increments fs' reference count and returns true. If fs' reference -// count is zero, TryIncRef does nothing and returns false. -// -// TryIncRef does not require that a reference is held on fs. -func (fs *Filesystem) TryIncRef() bool { - for { - refs := atomic.LoadInt64(&fs.refs) - if refs <= 0 { - return false - } - if atomic.CompareAndSwapInt64(&fs.refs, refs, refs+1) { - return true - } - } -} - // DecRef decrements fs' reference count. func (fs *Filesystem) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&fs.refs, -1); refs == 0 { + fs.FilesystemRefs.DecRef(func() { fs.vfs.filesystemsMu.Lock() delete(fs.vfs.filesystems, fs) fs.vfs.filesystemsMu.Unlock() fs.impl.Release(ctx) - } else if refs < 0 { - panic("Filesystem.decRef() called without holding a reference") - } + }) } // FilesystemImpl contains implementation details for a Filesystem. @@ -212,8 +183,9 @@ type FilesystemImpl interface { // ENOENT. Equivalently, if vd represents a file with a link count of 0 not // created by open(O_TMPFILE) without O_EXCL, LinkAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If LinkAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -231,8 +203,9 @@ type FilesystemImpl interface { // - If the directory in which the new directory would be created has been // removed by RmdirAt or RenameAt, MkdirAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If MkdirAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -253,8 +226,9 @@ type FilesystemImpl interface { // - If the directory in which the file would be created has been removed // by RmdirAt or RenameAt, MknodAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If MknodAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -345,11 +319,12 @@ type FilesystemImpl interface { // - If renaming would replace a non-empty directory, RenameAt returns // ENOTEMPTY. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). oldParentVD.Dentry() was obtained from a - // previous call to - // oldParentVD.Mount().Filesystem().Impl().GetParentDentryAt(). oldName is - // not "." or "..". + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). + // * oldParentVD.Dentry() was obtained from a previous call to + // oldParentVD.Mount().Filesystem().Impl().GetParentDentryAt(). + // * oldName is not "." or "..". // // Postconditions: If RenameAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -372,8 +347,9 @@ type FilesystemImpl interface { // - If the file at rp exists but is not a directory, RmdirAt returns // ENOTDIR. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If RmdirAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -410,8 +386,9 @@ type FilesystemImpl interface { // - If the directory in which the symbolic link would be created has been // removed by RmdirAt or RenameAt, SymlinkAt returns ENOENT. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If SymlinkAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). @@ -431,33 +408,34 @@ type FilesystemImpl interface { // // - If the file at rp exists but is a directory, UnlinkAt returns EISDIR. // - // Preconditions: !rp.Done(). For the final path component in rp, - // !rp.ShouldFollowSymlink(). + // Preconditions: + // * !rp.Done(). + // * For the final path component in rp, !rp.ShouldFollowSymlink(). // // Postconditions: If UnlinkAt returns an error returned by // ResolvingPath.Resolve*(), then !rp.Done(). UnlinkAt(ctx context.Context, rp *ResolvingPath) error - // ListxattrAt returns all extended attribute names for the file at rp. + // ListXattrAt returns all extended attribute names for the file at rp. // // Errors: // // - If extended attributes are not supported by the filesystem, - // ListxattrAt returns ENOTSUP. + // ListXattrAt returns ENOTSUP. // // - If the size of the list (including a NUL terminating byte after every // entry) would exceed size, ERANGE may be returned. Note that // implementations are free to ignore size entirely and return without // error). In all cases, if size is 0, the list should be returned without // error, regardless of size. - ListxattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) + ListXattrAt(ctx context.Context, rp *ResolvingPath, size uint64) ([]string, error) - // GetxattrAt returns the value associated with the given extended + // GetXattrAt returns the value associated with the given extended // attribute for the file at rp. // // Errors: // - // - If extended attributes are not supported by the filesystem, GetxattrAt + // - If extended attributes are not supported by the filesystem, GetXattrAt // returns ENOTSUP. // // - If an extended attribute named opts.Name does not exist, ENODATA is @@ -467,30 +445,30 @@ type FilesystemImpl interface { // returned (note that implementations are free to ignore opts.Size entirely // and return without error). In all cases, if opts.Size is 0, the value // should be returned without error, regardless of size. - GetxattrAt(ctx context.Context, rp *ResolvingPath, opts GetxattrOptions) (string, error) + GetXattrAt(ctx context.Context, rp *ResolvingPath, opts GetXattrOptions) (string, error) - // SetxattrAt changes the value associated with the given extended + // SetXattrAt changes the value associated with the given extended // attribute for the file at rp. // // Errors: // - // - If extended attributes are not supported by the filesystem, SetxattrAt + // - If extended attributes are not supported by the filesystem, SetXattrAt // returns ENOTSUP. // // - If XATTR_CREATE is set in opts.Flag and opts.Name already exists, // EEXIST is returned. If XATTR_REPLACE is set and opts.Name does not exist, // ENODATA is returned. - SetxattrAt(ctx context.Context, rp *ResolvingPath, opts SetxattrOptions) error + SetXattrAt(ctx context.Context, rp *ResolvingPath, opts SetXattrOptions) error - // RemovexattrAt removes the given extended attribute from the file at rp. + // RemoveXattrAt removes the given extended attribute from the file at rp. // // Errors: // // - If extended attributes are not supported by the filesystem, - // RemovexattrAt returns ENOTSUP. + // RemoveXattrAt returns ENOTSUP. // // - If name does not exist, ENODATA is returned. - RemovexattrAt(ctx context.Context, rp *ResolvingPath, name string) error + RemoveXattrAt(ctx context.Context, rp *ResolvingPath, name string) error // BoundEndpointAt returns the Unix socket endpoint bound at the path rp. // diff --git a/pkg/sentry/vfs/filesystem_impl_util.go b/pkg/sentry/vfs/filesystem_impl_util.go index 465e610e0..2620cf975 100644 --- a/pkg/sentry/vfs/filesystem_impl_util.go +++ b/pkg/sentry/vfs/filesystem_impl_util.go @@ -16,6 +16,9 @@ package vfs import ( "strings" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/usermem" ) // GenericParseMountOptions parses a comma-separated list of options of the @@ -41,3 +44,13 @@ func GenericParseMountOptions(str string) map[string]string { } return m } + +// GenericStatFS returns a statfs struct filled with the common fields for a +// general filesystem. This is analogous to Linux's fs/libfs.cs:simple_statfs(). +func GenericStatFS(fsMagic uint64) linux.Statfs { + return linux.Statfs{ + Type: fsMagic, + BlockSize: usermem.PageSize, + NameLength: linux.NAME_MAX, + } +} diff --git a/pkg/sentry/vfs/g3doc/inotify.md b/pkg/sentry/vfs/g3doc/inotify.md index e7da49faa..833db213f 100644 --- a/pkg/sentry/vfs/g3doc/inotify.md +++ b/pkg/sentry/vfs/g3doc/inotify.md @@ -28,9 +28,9 @@ The set of all watches held on a single file (i.e., the watch target) is stored in vfs.Watches. Each watch will belong to a different inotify instance (an instance can only have one watch on any watch target). The watches are stored in a map indexed by their vfs.Inotify owner’s id. Hard links and file descriptions -to a single file will all share the same vfs.Watches. Activity on the target -causes its vfs.Watches to generate notifications on its watches’ inotify -instances. +to a single file will all share the same vfs.Watches (with the exception of the +gofer filesystem, described in a later section). Activity on the target causes +its vfs.Watches to generate notifications on its watches’ inotify instances. ### vfs.Watch @@ -103,12 +103,12 @@ inotify: unopened p9 file (and possibly an open FID), through which the Sentry interacts with the gofer. * *Solution:* Because there is no inode structure stored in the sandbox, - inotify watches must be held on the dentry. This would be an issue in - the presence of hard links, where multiple dentries would need to share - the same set of watches, but in VFS2, we do not support the internal - creation of hard links on gofer fs. As a result, we make the assumption - that every dentry corresponds to a unique inode. However, the next point - raises an issue with this assumption: + inotify watches must be held on the dentry. For the purposes of inotify, + we assume that every dentry corresponds to a unique inode, which may + cause unexpected behavior in the presence of hard links, where multiple + dentries should share the same set of watches. Indeed, it is impossible + for us to be absolutely sure whether dentries correspond to the same + file or not, due to the following point: * **The Sentry cannot always be aware of hard links on the remote filesystem.** There is no way for us to confirm whether two files on the remote filesystem are actually links to the same inode. QIDs and inodes are diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 6c7583a81..42666eebf 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -46,7 +46,13 @@ func (fl *FileLocks) LockBSD(uid fslock.UniqueID, t fslock.LockType, block fsloc if fl.bsd.LockRegion(uid, t, fslock.LockRange{0, fslock.LockEOF}, block) { return nil } - return syserror.ErrWouldBlock + + // Return an appropriate error for the unsuccessful lock attempt, depending on + // whether this is a blocking or non-blocking operation. + if block == nil { + return syserror.ErrWouldBlock + } + return syserror.ERESTARTSYS } // UnlockBSD releases a BSD-style lock on the entire file. @@ -66,7 +72,13 @@ func (fl *FileLocks) LockPOSIX(ctx context.Context, fd *FileDescription, uid fsl if fl.posix.LockRegion(uid, t, rng, block) { return nil } - return syserror.ErrWouldBlock + + // Return an appropriate error for the unsuccessful lock attempt, depending on + // whether this is a blocking or non-blocking operation. + if block == nil { + return syserror.ErrWouldBlock + } + return syserror.ERESTARTSYS } // UnlockPOSIX releases a POSIX-style lock on a file region. diff --git a/pkg/sentry/vfs/memxattr/xattr.go b/pkg/sentry/vfs/memxattr/xattr.go index cc1e7d764..638b5d830 100644 --- a/pkg/sentry/vfs/memxattr/xattr.go +++ b/pkg/sentry/vfs/memxattr/xattr.go @@ -33,8 +33,8 @@ type SimpleExtendedAttributes struct { xattrs map[string]string } -// Getxattr returns the value at 'name'. -func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, error) { +// GetXattr returns the value at 'name'. +func (x *SimpleExtendedAttributes) GetXattr(opts *vfs.GetXattrOptions) (string, error) { x.mu.RLock() value, ok := x.xattrs[opts.Name] x.mu.RUnlock() @@ -49,8 +49,8 @@ func (x *SimpleExtendedAttributes) Getxattr(opts *vfs.GetxattrOptions) (string, return value, nil } -// Setxattr sets 'value' at 'name'. -func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { +// SetXattr sets 'value' at 'name'. +func (x *SimpleExtendedAttributes) SetXattr(opts *vfs.SetXattrOptions) error { x.mu.Lock() defer x.mu.Unlock() if x.xattrs == nil { @@ -72,8 +72,8 @@ func (x *SimpleExtendedAttributes) Setxattr(opts *vfs.SetxattrOptions) error { return nil } -// Listxattr returns all names in xattrs. -func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { +// ListXattr returns all names in xattrs. +func (x *SimpleExtendedAttributes) ListXattr(size uint64) ([]string, error) { // Keep track of the size of the buffer needed in listxattr(2) for the list. listSize := 0 x.mu.RLock() @@ -90,8 +90,8 @@ func (x *SimpleExtendedAttributes) Listxattr(size uint64) ([]string, error) { return names, nil } -// Removexattr removes the xattr at 'name'. -func (x *SimpleExtendedAttributes) Removexattr(name string) error { +// RemoveXattr removes the xattr at 'name'. +func (x *SimpleExtendedAttributes) RemoveXattr(name string) error { x.mu.Lock() defer x.mu.Unlock() if _, ok := x.xattrs[name]; !ok { diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index d1d29d0cd..9da09d4c1 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -114,7 +114,7 @@ func (mnt *Mount) Options() MountOptions { defer mnt.vfs.mountMu.Unlock() return MountOptions{ Flags: mnt.Flags, - ReadOnly: mnt.readOnly(), + ReadOnly: mnt.ReadOnly(), } } @@ -126,16 +126,14 @@ func (mnt *Mount) Options() MountOptions { // // +stateify savable type MountNamespace struct { + MountNamespaceRefs + // Owner is the usernamespace that owns this mount namespace. Owner *auth.UserNamespace // root is the MountNamespace's root mount. root is immutable. root *Mount - // refs is the reference count. refs is accessed using atomic memory - // operations. - refs int64 - // mountpoints maps all Dentries which are mount points in this namespace // to the number of Mounts for which they are mount points. mountpoints is // protected by VirtualFilesystem.mountMu. @@ -154,22 +152,22 @@ type MountNamespace struct { // NewMountNamespace returns a new mount namespace with a root filesystem // configured by the given arguments. A reference is taken on the returned // MountNamespace. -func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *GetFilesystemOptions) (*MountNamespace, error) { +func (vfs *VirtualFilesystem) NewMountNamespace(ctx context.Context, creds *auth.Credentials, source, fsTypeName string, opts *MountOptions) (*MountNamespace, error) { rft := vfs.getFilesystemType(fsTypeName) if rft == nil { ctx.Warningf("Unknown filesystem type: %s", fsTypeName) return nil, syserror.ENODEV } - fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, *opts) + fs, root, err := rft.fsType.GetFilesystem(ctx, vfs, creds, source, opts.GetFilesystemOptions) if err != nil { return nil, err } mntns := &MountNamespace{ Owner: creds.UserNamespace, - refs: 1, mountpoints: make(map[*Dentry]uint32), } - mntns.root = newMount(vfs, fs, root, mntns, &MountOptions{}) + mntns.EnableLeakCheck() + mntns.root = newMount(vfs, fs, root, mntns, opts) return mntns, nil } @@ -263,16 +261,20 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr } // MountAt creates and mounts a Filesystem configured by the given arguments. -func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) error { +// The VirtualFilesystem will hold a reference to the Mount until it is unmounted. +// +// This method returns the mounted Mount without a reference, for convenience +// during VFS setup when there is no chance of racing with unmount. +func (vfs *VirtualFilesystem) MountAt(ctx context.Context, creds *auth.Credentials, source string, target *PathOperation, fsTypeName string, opts *MountOptions) (*Mount, error) { mnt, err := vfs.MountDisconnected(ctx, creds, source, fsTypeName, opts) if err != nil { - return err + return nil, err } defer mnt.DecRef(ctx) if err := vfs.ConnectMountAt(ctx, creds, mnt, target); err != nil { - return err + return nil, err } - return nil + return mnt, nil } // UmountAt removes the Mount at the given path. @@ -369,8 +371,9 @@ type umountRecursiveOptions struct { // // umountRecursiveLocked is analogous to Linux's fs/namespace.c:umount_tree(). // -// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. +// Preconditions: +// * vfs.mountMu must be locked. +// * vfs.mounts.seq must be in a writer critical section. func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecursiveOptions, vdsToDecRef []VirtualDentry, mountsToDecRef []*Mount) ([]VirtualDentry, []*Mount) { if !mnt.umounted { mnt.umounted = true @@ -399,9 +402,11 @@ func (vfs *VirtualFilesystem) umountRecursiveLocked(mnt *Mount, opts *umountRecu // connectLocked makes vd the mount parent/point for mnt. It consumes // references held by vd. // -// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. d.mu must be locked. mnt.parent() == nil, i.e. mnt -// must not already be connected. +// Preconditions: +// * vfs.mountMu must be locked. +// * vfs.mounts.seq must be in a writer critical section. +// * d.mu must be locked. +// * mnt.parent() == nil, i.e. mnt must not already be connected. func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns *MountNamespace) { if checkInvariants { if mnt.parent() != nil { @@ -429,8 +434,10 @@ func (vfs *VirtualFilesystem) connectLocked(mnt *Mount, vd VirtualDentry, mntns // disconnectLocked makes vd have no mount parent/point and returns its old // mount parent/point with a reference held. // -// Preconditions: vfs.mountMu must be locked. vfs.mounts.seq must be in a -// writer critical section. mnt.parent() != nil. +// Preconditions: +// * vfs.mountMu must be locked. +// * vfs.mounts.seq must be in a writer critical section. +// * mnt.parent() != nil. func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { vd := mnt.loadKey() if checkInvariants { @@ -498,17 +505,10 @@ func (mnt *Mount) DecRef(ctx context.Context) { } } -// IncRef increments mntns' reference count. -func (mntns *MountNamespace) IncRef() { - if atomic.AddInt64(&mntns.refs, 1) <= 1 { - panic("MountNamespace.IncRef() called without holding a reference") - } -} - // DecRef decrements mntns' reference count. func (mntns *MountNamespace) DecRef(ctx context.Context) { vfs := mntns.root.fs.VirtualFilesystem() - if refs := atomic.AddInt64(&mntns.refs, -1); refs == 0 { + mntns.MountNamespaceRefs.DecRef(func() { vfs.mountMu.Lock() vfs.mounts.seq.BeginWrite() vdsToDecRef, mountsToDecRef := vfs.umountRecursiveLocked(mntns.root, &umountRecursiveOptions{ @@ -522,9 +522,7 @@ func (mntns *MountNamespace) DecRef(ctx context.Context) { for _, mnt := range mountsToDecRef { mnt.DecRef(ctx) } - } else if refs < 0 { - panic("MountNamespace.DecRef() called without holding a reference") - } + }) } // getMountAt returns the last Mount in the stack mounted at (mnt, d). It takes @@ -576,8 +574,9 @@ retryFirst: // mnt. It takes a reference on the returned VirtualDentry. If no such mount // point exists (i.e. mnt is a root mount), getMountpointAt returns (nil, nil). // -// Preconditions: References are held on mnt and root. vfsroot is not (mnt, -// mnt.root). +// Preconditions: +// * References are held on mnt and root. +// * vfsroot is not (mnt, mnt.root). func (vfs *VirtualFilesystem) getMountpointAt(ctx context.Context, mnt *Mount, vfsroot VirtualDentry) VirtualDentry { // The first mount is special-cased: // @@ -651,6 +650,13 @@ retryFirst: return VirtualDentry{mnt, d} } +// SetMountReadOnly sets the mount as ReadOnly. +func (vfs *VirtualFilesystem) SetMountReadOnly(mnt *Mount, ro bool) error { + vfs.mountMu.Lock() + defer vfs.mountMu.Unlock() + return mnt.setReadOnlyLocked(ro) +} + // CheckBeginWrite increments the counter of in-progress write operations on // mnt. If mnt is mounted MS_RDONLY, CheckBeginWrite does nothing and returns // EROFS. @@ -688,7 +694,8 @@ func (mnt *Mount) setReadOnlyLocked(ro bool) error { return nil } -func (mnt *Mount) readOnly() bool { +// ReadOnly returns true if mount is readonly. +func (mnt *Mount) ReadOnly() bool { return atomic.LoadInt64(&mnt.writers) < 0 } @@ -731,11 +738,23 @@ func (mntns *MountNamespace) Root() VirtualDentry { // // Preconditions: taskRootDir.Ok(). func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) { - vfs.mountMu.Lock() - defer vfs.mountMu.Unlock() rootMnt := taskRootDir.mount + + vfs.mountMu.Lock() mounts := rootMnt.submountsLocked() + // Take a reference on mounts since we need to drop vfs.mountMu before + // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()). + for _, mnt := range mounts { + mnt.IncRef() + } + vfs.mountMu.Unlock() + defer func() { + for _, mnt := range mounts { + mnt.DecRef(ctx) + } + }() sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID }) + for _, mnt := range mounts { // Get the path to this mount relative to task root. mntRootVD := VirtualDentry{ @@ -746,7 +765,7 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi if err != nil { // For some reason we didn't get a path. Log a warning // and run with empty path. - ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err) + ctx.Warningf("VFS.GenerateProcMounts: error getting pathname for mount root %+v: %v", mnt.root, err) path = "" } if path == "" { @@ -756,7 +775,7 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi } opts := "rw" - if mnt.readOnly() { + if mnt.ReadOnly() { opts = "ro" } if mnt.Flags.NoATime { @@ -780,11 +799,25 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi // // Preconditions: taskRootDir.Ok(). func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRootDir VirtualDentry, buf *bytes.Buffer) { - vfs.mountMu.Lock() - defer vfs.mountMu.Unlock() rootMnt := taskRootDir.mount + + vfs.mountMu.Lock() mounts := rootMnt.submountsLocked() + // Take a reference on mounts since we need to drop vfs.mountMu before + // calling vfs.PathnameReachable() (=> FilesystemImpl.PrependPath()) or + // vfs.StatAt() (=> FilesystemImpl.StatAt()). + for _, mnt := range mounts { + mnt.IncRef() + } + vfs.mountMu.Unlock() + defer func() { + for _, mnt := range mounts { + mnt.DecRef(ctx) + } + }() sort.Slice(mounts, func(i, j int) bool { return mounts[i].ID < mounts[j].ID }) + + creds := auth.CredentialsFromContext(ctx) for _, mnt := range mounts { // Get the path to this mount relative to task root. mntRootVD := VirtualDentry{ @@ -795,7 +828,7 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo if err != nil { // For some reason we didn't get a path. Log a warning // and run with empty path. - ctx.Warningf("Error getting pathname for mount root %+v: %v", mnt.root, err) + ctx.Warningf("VFS.GenerateProcMountInfo: error getting pathname for mount root %+v: %v", mnt.root, err) path = "" } if path == "" { @@ -808,9 +841,10 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo Root: mntRootVD, Start: mntRootVD, } - statx, err := vfs.StatAt(ctx, auth.NewAnonymousCredentials(), pop, &StatOptions{}) + statx, err := vfs.StatAt(ctx, creds, pop, &StatOptions{}) if err != nil { // Well that's not good. Ignore this mount. + ctx.Warningf("VFS.GenerateProcMountInfo: failed to stat mount root %+v: %v", mnt.root, err) break } @@ -822,6 +856,9 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo fmt.Fprintf(buf, "%d ", mnt.ID) // (2) Parent ID (or this ID if there is no parent). + // Note that even if the call to mnt.parent() races with Mount + // destruction (which is possible since we're not holding vfs.mountMu), + // its Mount.ID will still be valid. pID := mnt.ID if p := mnt.parent(); p != nil { pID = p.ID @@ -844,7 +881,7 @@ func (vfs *VirtualFilesystem) GenerateProcMountInfo(ctx context.Context, taskRoo // (6) Mount options. opts := "rw" - if mnt.readOnly() { + if mnt.ReadOnly() { opts = "ro" } if mnt.Flags.NoATime { @@ -883,7 +920,7 @@ func superBlockOpts(mountPath string, mnt *Mount) string { // gVisor doesn't (yet) have a concept of super block options, so we // use the ro/rw bit from the mount flag. opts := "rw" - if mnt.readOnly() { + if mnt.ReadOnly() { opts = "ro" } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index 70f850ca4..da2a2e9c4 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. @@ -217,8 +217,9 @@ func (mt *mountTable) Insert(mount *Mount) { // insertSeqed inserts the given mount into mt. // -// Preconditions: mt.seq must be in a writer critical section. mt must not -// already contain a Mount with the same mount point and parent. +// Preconditions: +// * mt.seq must be in a writer critical section. +// * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) @@ -269,9 +270,11 @@ func (mt *mountTable) insertSeqed(mount *Mount) { atomic.StorePointer(&mt.slots, newSlots) } -// Preconditions: There are no concurrent mutators of the table (slots, cap). -// If the table is visible to readers, then mt.seq must be in a writer critical -// section. cap must be a power of 2. +// Preconditions: +// * There are no concurrent mutators of the table (slots, cap). +// * If the table is visible to readers, then mt.seq must be in a writer +// critical section. +// * cap must be a power of 2. func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, hash uintptr) { mask := cap - 1 off := (hash & mask) * mountSlotBytes @@ -313,8 +316,9 @@ func (mt *mountTable) Remove(mount *Mount) { // removeSeqed removes the given mount from mt. // -// Preconditions: mt.seq must be in a writer critical section. mt must contain -// mount. +// Preconditions: +// * mt.seq must be in a writer critical section. +// * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) tcap := uintptr(1) << (mt.size & mtSizeOrderMask) diff --git a/pkg/sentry/vfs/options.go b/pkg/sentry/vfs/options.go index dfc8573fd..b33d36cb1 100644 --- a/pkg/sentry/vfs/options.go +++ b/pkg/sentry/vfs/options.go @@ -190,10 +190,10 @@ type BoundEndpointOptions struct { Addr string } -// GetxattrOptions contains options to VirtualFilesystem.GetxattrAt(), -// FilesystemImpl.GetxattrAt(), FileDescription.Getxattr(), and -// FileDescriptionImpl.Getxattr(). -type GetxattrOptions struct { +// GetXattrOptions contains options to VirtualFilesystem.GetXattrAt(), +// FilesystemImpl.GetXattrAt(), FileDescription.GetXattr(), and +// FileDescriptionImpl.GetXattr(). +type GetXattrOptions struct { // Name is the name of the extended attribute to retrieve. Name string @@ -204,10 +204,10 @@ type GetxattrOptions struct { Size uint64 } -// SetxattrOptions contains options to VirtualFilesystem.SetxattrAt(), -// FilesystemImpl.SetxattrAt(), FileDescription.Setxattr(), and -// FileDescriptionImpl.Setxattr(). -type SetxattrOptions struct { +// SetXattrOptions contains options to VirtualFilesystem.SetXattrAt(), +// FilesystemImpl.SetXattrAt(), FileDescription.SetXattr(), and +// FileDescriptionImpl.SetXattr(). +type SetXattrOptions struct { // Name is the name of the extended attribute being mutated. Name string diff --git a/pkg/sentry/vfs/permissions.go b/pkg/sentry/vfs/permissions.go index 33389c1df..00eeb8842 100644 --- a/pkg/sentry/vfs/permissions.go +++ b/pkg/sentry/vfs/permissions.go @@ -16,6 +16,7 @@ package vfs import ( "math" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -271,7 +272,7 @@ func HasCapabilityOnFile(creds *auth.Credentials, cp linux.Capability, kuid auth // operation must not proceed. Otherwise it returns the max length allowed to // without violating the limit. func CheckLimit(ctx context.Context, offset, size int64) (int64, error) { - fileSizeLimit := limits.FromContext(ctx).Get(limits.FileSize).Cur + fileSizeLimit := limits.FromContextOrDie(ctx).Get(limits.FileSize).Cur if fileSizeLimit > math.MaxInt64 { return size, nil } @@ -284,3 +285,40 @@ func CheckLimit(ctx context.Context, offset, size int64) (int64, error) { } return size, nil } + +// CheckXattrPermissions checks permissions for extended attribute access. +// This is analogous to fs/xattr.c:xattr_permission(). Some key differences: +// * Does not check for read-only filesystem property. +// * Does not check inode immutability or append only mode. In both cases EPERM +// must be returned by filesystem implementations. +// * Does not do inode permission checks. Filesystem implementations should +// handle inode permission checks as they may differ across implementations. +func CheckXattrPermissions(creds *auth.Credentials, ats AccessTypes, mode linux.FileMode, kuid auth.KUID, name string) error { + switch { + case strings.HasPrefix(name, linux.XATTR_TRUSTED_PREFIX): + // The trusted.* namespace can only be accessed by privileged + // users. + if creds.HasCapability(linux.CAP_SYS_ADMIN) { + return nil + } + if ats.MayWrite() { + return syserror.EPERM + } + return syserror.ENODATA + case strings.HasPrefix(name, linux.XATTR_USER_PREFIX): + // In the user.* namespace, only regular files and directories can have + // extended attributes. For sticky directories, only the owner and + // privileged users can write attributes. + filetype := mode.FileType() + if filetype != linux.ModeRegular && filetype != linux.ModeDirectory { + if ats.MayWrite() { + return syserror.EPERM + } + return syserror.ENODATA + } + if filetype == linux.ModeDirectory && mode&linux.ModeSticky != 0 && ats.MayWrite() && !CanActAsOwner(creds, kuid) { + return syserror.EPERM + } + } + return nil +} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 9c2420683..1ebf355ef 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -24,9 +24,9 @@ // Locks acquired by FilesystemImpls between Prepare{Delete,Rename}Dentry and Commit{Delete,Rename*}Dentry // VirtualFilesystem.filesystemsMu // EpollInstance.mu -// Inotify.mu -// Watches.mu -// Inotify.evMu +// Inotify.mu +// Watches.mu +// Inotify.evMu // VirtualFilesystem.fsTypesMu // // Locking Dentry.mu in multiple Dentries requires holding @@ -36,6 +36,7 @@ package vfs import ( "fmt" + "path" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -296,6 +297,8 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential // MkdirAt creates a directory at the given path. func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MkdirOptions) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with mkdirat(dirfd, "", mode). if pop.Path.Absolute { return syserror.EEXIST } @@ -332,6 +335,8 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia // error from the syserror package. func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *MknodOptions) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with mknodat(dirfd, "", mode, dev). if pop.Path.Absolute { return syserror.EEXIST } @@ -517,6 +522,8 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti // RmdirAt removes the directory at the given path. func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with unlinkat(dirfd, "", AT_REMOVEDIR). if pop.Path.Absolute { return syserror.EBUSY } @@ -598,6 +605,8 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti // SymlinkAt creates a symbolic link at the given path with the given target. func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, target string) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with symlinkat(oldpath, newdirfd, ""). if pop.Path.Absolute { return syserror.EEXIST } @@ -630,6 +639,8 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent // UnlinkAt deletes the non-directory file at the given path. func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation) error { if !pop.Path.Begin.Ok() { + // pop.Path should not be empty in operations that create/delete files. + // This is consistent with unlinkat(dirfd, "", 0). if pop.Path.Absolute { return syserror.EBUSY } @@ -661,12 +672,6 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti // BoundEndpointAt gets the bound endpoint at the given path, if one exists. func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *BoundEndpointOptions) (transport.BoundEndpoint, error) { - if !pop.Path.Begin.Ok() { - if pop.Path.Absolute { - return nil, syserror.ECONNREFUSED - } - return nil, syserror.ENOENT - } rp := vfs.getResolvingPath(creds, pop) for { bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts) @@ -686,12 +691,12 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C } } -// ListxattrAt returns all extended attribute names for the file at the given +// ListXattrAt returns all extended attribute names for the file at the given // path. -func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { +func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, size uint64) ([]string, error) { rp := vfs.getResolvingPath(creds, pop) for { - names, err := rp.mount.fs.impl.ListxattrAt(ctx, rp, size) + names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size) if err == nil { vfs.putResolvingPath(ctx, rp) return names, nil @@ -711,12 +716,12 @@ func (vfs *VirtualFilesystem) ListxattrAt(ctx context.Context, creds *auth.Crede } } -// GetxattrAt returns the value associated with the given extended attribute +// GetXattrAt returns the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetxattrOptions) (string, error) { +func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *GetXattrOptions) (string, error) { rp := vfs.getResolvingPath(creds, pop) for { - val, err := rp.mount.fs.impl.GetxattrAt(ctx, rp, *opts) + val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts) if err == nil { vfs.putResolvingPath(ctx, rp) return val, nil @@ -728,12 +733,12 @@ func (vfs *VirtualFilesystem) GetxattrAt(ctx context.Context, creds *auth.Creden } } -// SetxattrAt changes the value associated with the given extended attribute +// SetXattrAt changes the value associated with the given extended attribute // for the file at the given path. -func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetxattrOptions) error { +func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *SetXattrOptions) error { rp := vfs.getResolvingPath(creds, pop) for { - err := rp.mount.fs.impl.SetxattrAt(ctx, rp, *opts) + err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts) if err == nil { vfs.putResolvingPath(ctx, rp) return nil @@ -745,11 +750,11 @@ func (vfs *VirtualFilesystem) SetxattrAt(ctx context.Context, creds *auth.Creden } } -// RemovexattrAt removes the given extended attribute from the file at rp. -func (vfs *VirtualFilesystem) RemovexattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { +// RemoveXattrAt removes the given extended attribute from the file at rp. +func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, name string) error { rp := vfs.getResolvingPath(creds, pop) for { - err := rp.mount.fs.impl.RemovexattrAt(ctx, rp, name) + err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name) if err == nil { vfs.putResolvingPath(ctx, rp) return nil @@ -782,6 +787,62 @@ func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { return retErr } +// MkdirAllAt recursively creates non-existent directories on the given path +// (including the last component). +func (vfs *VirtualFilesystem) MkdirAllAt(ctx context.Context, currentPath string, root VirtualDentry, creds *auth.Credentials, mkdirOpts *MkdirOptions) error { + pop := &PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(currentPath), + } + stat, err := vfs.StatAt(ctx, creds, pop, &StatOptions{Mask: linux.STATX_TYPE}) + switch err { + case nil: + if stat.Mask&linux.STATX_TYPE == 0 || stat.Mode&linux.FileTypeMask != linux.ModeDirectory { + return syserror.ENOTDIR + } + // Directory already exists. + return nil + case syserror.ENOENT: + // Expected, we will create the dir. + default: + return fmt.Errorf("stat failed for %q during directory creation: %w", currentPath, err) + } + + // Recurse to ensure parent is created and then create the final directory. + if err := vfs.MkdirAllAt(ctx, path.Dir(currentPath), root, creds, mkdirOpts); err != nil { + return err + } + if err := vfs.MkdirAt(ctx, creds, pop, mkdirOpts); err != nil { + return fmt.Errorf("failed to create directory %q: %w", currentPath, err) + } + return nil +} + +// MakeSyntheticMountpoint creates parent directories of target if they do not +// exist and attempts to create a directory for the mountpoint. If a +// non-directory file already exists there then we allow it. +func (vfs *VirtualFilesystem) MakeSyntheticMountpoint(ctx context.Context, target string, root VirtualDentry, creds *auth.Credentials) error { + mkdirOpts := &MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true} + + // Make sure the parent directory of target exists. + if err := vfs.MkdirAllAt(ctx, path.Dir(target), root, creds, mkdirOpts); err != nil { + return fmt.Errorf("failed to create parent directory of mountpoint %q: %w", target, err) + } + + // Attempt to mkdir the final component. If a file (of any type) exists + // then we let allow mounting on top of that because we do not require the + // target to be an existing directory, unlike Linux mount(2). + if err := vfs.MkdirAt(ctx, creds, &PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(target), + }, mkdirOpts); err != nil && err != syserror.EEXIST { + return fmt.Errorf("failed to create mountpoint %q: %w", target, err) + } + return nil +} + // A VirtualDentry represents a node in a VFS tree, by combining a Dentry // (which represents a node in a Filesystem's tree) and a Mount (which // represents the Filesystem's position in a VFS mount tree). diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 748273366..bbafb8b7f 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -96,15 +96,33 @@ const ( Panic ) +// Set implements flag.Value. +func (a *Action) Set(v string) error { + switch v { + case "log", "logwarning": + *a = LogWarning + case "panic": + *a = Panic + default: + return fmt.Errorf("invalid watchdog action %q", v) + } + return nil +} + +// Get implements flag.Value. +func (a *Action) Get() interface{} { + return *a +} + // String returns Action's string representation. -func (a Action) String() string { - switch a { +func (a *Action) String() string { + switch *a { case LogWarning: - return "LogWarning" + return "logWarning" case Panic: - return "Panic" + return "panic" default: - panic(fmt.Sprintf("Invalid action: %d", a)) + panic(fmt.Sprintf("Invalid watchdog action: %d", *a)) } } diff --git a/pkg/shim/v2/runtimeoptions/BUILD b/pkg/shim/v2/runtimeoptions/BUILD index 01716034c..ba2ed1ea7 100644 --- a/pkg/shim/v2/runtimeoptions/BUILD +++ b/pkg/shim/v2/runtimeoptions/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library", "proto_library") +load("//tools:defs.bzl", "go_library", "go_test", "proto_library") package(licenses = ["notice"]) @@ -14,7 +14,19 @@ go_library( srcs = ["runtimeoptions.go"], visibility = ["//pkg/shim/v2:__pkg__"], deps = [ - "//pkg/shim/v2/runtimeoptions:api_go_proto", + ":api_go_proto", "@com_github_gogo_protobuf//proto:go_default_library", ], ) + +go_test( + name = "runtimeoptions_test", + size = "small", + srcs = ["runtimeoptions_test.go"], + library = ":runtimeoptions", + deps = [ + "@com_github_containerd_containerd//runtime/v1/shim/v1:go_default_library", + "@com_github_containerd_typeurl//:go_default_library", + "@com_github_golang_protobuf//proto:go_default_library", + ], +) diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.go b/pkg/shim/v2/runtimeoptions/runtimeoptions.go index 1c1a0c5d1..aaf17b87a 100644 --- a/pkg/shim/v2/runtimeoptions/runtimeoptions.go +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.go @@ -23,5 +23,8 @@ import ( type Options = pb.Options func init() { + // The generated proto file auto registers with "golang/protobuf/proto" + // package. However, typeurl uses "golang/gogo/protobuf/proto". So registers + // the type there too. proto.RegisterType((*Options)(nil), "cri.runtimeoptions.v1.Options") } diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto index edb19020a..057032e34 100644 --- a/pkg/shim/v2/runtimeoptions/runtimeoptions.proto +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions.proto @@ -14,11 +14,11 @@ syntax = "proto3"; -package runtimeoptions; +package cri.runtimeoptions.v1; // This is a version of the runtimeoptions CRI API that is vendored. // -// Imported the full CRI package is a nightmare. +// Importing the full CRI package is a nightmare. message Options { string type_url = 1; string config_path = 2; diff --git a/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go new file mode 100644 index 000000000..f4c238a00 --- /dev/null +++ b/pkg/shim/v2/runtimeoptions/runtimeoptions_test.go @@ -0,0 +1,52 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package runtimeoptions + +import ( + "testing" + + shim "github.com/containerd/containerd/runtime/v1/shim/v1" + "github.com/containerd/typeurl" + "github.com/golang/protobuf/proto" +) + +func TestCreateTaskRequest(t *testing.T) { + // Serialize the top-level message. + const encodedText = `options: < + type_url: "cri.runtimeoptions.v1.Options" + value: "\n\010type_url\022\013config_path" +>` + got := &shim.CreateTaskRequest{} // Should have raw options. + if err := proto.UnmarshalText(encodedText, got); err != nil { + t.Fatalf("unable to unmarshal text: %v", err) + } + t.Logf("got: %s", proto.MarshalTextString(got)) + + // Check the options. + wantOptions := &Options{} + wantOptions.TypeUrl = "type_url" + wantOptions.ConfigPath = "config_path" + gotMessage, err := typeurl.UnmarshalAny(got.Options) + if err != nil { + t.Fatalf("unable to unmarshal any: %v", err) + } + gotOptions, ok := gotMessage.(*Options) + if !ok { + t.Fatalf("got %v, want %v", gotMessage, wantOptions) + } + if !proto.Equal(gotOptions, wantOptions) { + t.Fatalf("got %v, want %v", gotOptions, wantOptions) + } +} diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go index 118805492..19bce2afb 100644 --- a/pkg/sleep/sleep_unsafe.go +++ b/pkg/sleep/sleep_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/state/decode.go b/pkg/state/decode.go index c9971cdf6..89467ca8e 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -584,10 +584,12 @@ func (ds *decodeState) Load(obj reflect.Value) { }) // Create the root object. - ds.objectsByID = append(ds.objectsByID, &objectDecodeState{ + rootOds := &objectDecodeState{ id: 1, obj: obj, - }) + } + ds.objectsByID = append(ds.objectsByID, rootOds) + ds.pending.PushBack(rootOds) // Read the number of objects. lastID, object, err := ReadHeader(ds.r) diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go index cf37aaa49..887f453a9 100644 --- a/pkg/state/pretty/pretty.go +++ b/pkg/state/pretty/pretty.go @@ -26,12 +26,17 @@ import ( "gvisor.dev/gvisor/pkg/state/wire" ) -func formatRef(x *wire.Ref, graph uint64, html bool) string { +type printer struct { + html bool + typeSpecs map[string]*wire.Type +} + +func (p *printer) formatRef(x *wire.Ref, graph uint64) string { baseRef := fmt.Sprintf("g%dr%d", graph, x.Root) fullRef := baseRef if len(x.Dots) > 0 { // See wire.Ref; Type valid if Dots non-zero. - typ, _ := formatType(x.Type, graph, html) + typ, _ := p.formatType(x.Type, graph) var buf strings.Builder buf.WriteString("(*") buf.WriteString(typ) @@ -51,34 +56,40 @@ func formatRef(x *wire.Ref, graph uint64, html bool) string { buf.WriteString(")") fullRef = buf.String() } - if html { + if p.html { return fmt.Sprintf("<a href=\"#%s\">%s</a>", baseRef, fullRef) } return fullRef } -func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) { +func (p *printer) formatType(t wire.TypeSpec, graph uint64) (string, bool) { switch x := t.(type) { case wire.TypeID: - base := fmt.Sprintf("g%dt%d", graph, x) - if html { - return fmt.Sprintf("<a href=\"#%s\">%s</a>", base, base), true + tag := fmt.Sprintf("g%dt%d", graph, x) + desc := tag + if spec, ok := p.typeSpecs[tag]; ok { + desc += fmt.Sprintf("=%s", spec.Name) + } else { + desc += "!missing-type-spec" + } + if p.html { + return fmt.Sprintf("<a href=\"#%s\">%s</a>", tag, desc), true } - return fmt.Sprintf("%s", base), true + return desc, true case wire.TypeSpecNil: return "", false // Only nil type. case *wire.TypeSpecPointer: - element, _ := formatType(x.Type, graph, html) + element, _ := p.formatType(x.Type, graph) return fmt.Sprintf("(*%s)", element), true case *wire.TypeSpecArray: - element, _ := formatType(x.Type, graph, html) + element, _ := p.formatType(x.Type, graph) return fmt.Sprintf("[%d](%s)", x.Count, element), true case *wire.TypeSpecSlice: - element, _ := formatType(x.Type, graph, html) + element, _ := p.formatType(x.Type, graph) return fmt.Sprintf("([]%s)", element), true case *wire.TypeSpecMap: - key, _ := formatType(x.Key, graph, html) - value, _ := formatType(x.Value, graph, html) + key, _ := p.formatType(x.Key, graph) + value, _ := p.formatType(x.Value, graph) return fmt.Sprintf("(map[%s]%s)", key, value), true default: panic(fmt.Sprintf("unreachable: unknown type %T", t)) @@ -87,7 +98,7 @@ func formatType(t wire.TypeSpec, graph uint64, html bool) (string, bool) { // format formats a single object, for pretty-printing. It also returns whether // the value is a non-zero value. -func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bool) { +func (p *printer) format(graph uint64, depth int, encoded wire.Object) (string, bool) { switch x := encoded.(type) { case wire.Nil: return "nil", false @@ -98,7 +109,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo case *wire.Complex128: return fmt.Sprintf("%f+%fi", real(*x), imag(*x)), *x != 0.0 case *wire.Ref: - return formatRef(x, graph, html), x.Root != 0 + return p.formatRef(x, graph), x.Root != 0 case *wire.Type: tabs := "\n" + strings.Repeat("\t", depth) items := make([]string, 0, len(x.Fields)+2) @@ -109,7 +120,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "}") return strings.Join(items, tabs), true // No zero value. case *wire.Slice: - return fmt.Sprintf("%s{len:%d,cap:%d}", formatRef(&x.Ref, graph, html), x.Length, x.Capacity), x.Capacity != 0 + return fmt.Sprintf("%s{len:%d,cap:%d}", p.formatRef(&x.Ref, graph), x.Length, x.Capacity), x.Capacity != 0 case *wire.Array: if len(x.Contents) == 0 { return "[]", false @@ -119,7 +130,7 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "[") tabs := "\n" + strings.Repeat("\t", depth) for i := 0; i < len(x.Contents); i++ { - item, ok := format(graph, depth+1, x.Contents[i], html) + item, ok := p.format(graph, depth+1, x.Contents[i]) if !ok { zeros = append(zeros, fmt.Sprintf("\t%s,", item)) continue @@ -136,7 +147,9 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "]") return strings.Join(items, tabs), len(zeros) < len(x.Contents) case *wire.Struct: - typ, _ := formatType(x.TypeID, graph, html) + tag := fmt.Sprintf("g%dt%d", graph, x.TypeID) + spec, _ := p.typeSpecs[tag] + typ, _ := p.formatType(x.TypeID, graph) if x.Fields() == 0 { return fmt.Sprintf("struct[%s]{}", typ), false } @@ -145,10 +158,15 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo tabs := "\n" + strings.Repeat("\t", depth) allZero := true for i := 0; i < x.Fields(); i++ { - element, ok := format(graph, depth+1, *x.Field(i), html) + var name string + if spec != nil && i < len(spec.Fields) { + name = spec.Fields[i] + } else { + name = fmt.Sprintf("%d", i) + } + element, ok := p.format(graph, depth+1, *x.Field(i)) allZero = allZero && !ok - items = append(items, fmt.Sprintf("\t%d: %s,", i, element)) - i++ + items = append(items, fmt.Sprintf("\t%s: %s,", name, element)) } items = append(items, "}") return strings.Join(items, tabs), !allZero @@ -160,15 +178,15 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo items = append(items, "map{") tabs := "\n" + strings.Repeat("\t", depth) for i := 0; i < len(x.Keys); i++ { - key, _ := format(graph, depth+1, x.Keys[i], html) - value, _ := format(graph, depth+1, x.Values[i], html) + key, _ := p.format(graph, depth+1, x.Keys[i]) + value, _ := p.format(graph, depth+1, x.Values[i]) items = append(items, fmt.Sprintf("\t%s: %s,", key, value)) } items = append(items, "}") return strings.Join(items, tabs), true case *wire.Interface: - typ, typOk := formatType(x.Type, graph, html) - element, elementOk := format(graph, depth+1, x.Value, html) + typ, typOk := p.formatType(x.Type, graph) + element, elementOk := p.format(graph, depth+1, x.Value) return fmt.Sprintf("interface[%s]{%s}", typ, element), typOk || elementOk default: // Must be a primitive; use reflection. @@ -177,11 +195,11 @@ func format(graph uint64, depth int, encoded wire.Object, html bool) (string, bo } // printStream is the basic print implementation. -func printStream(w io.Writer, r wire.Reader, html bool) (err error) { +func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { // current graph ID. var graph uint64 - if html { + if p.html { fmt.Fprintf(w, "<pre>") defer fmt.Fprintf(w, "</pre>") } @@ -196,6 +214,8 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { } }() + p.typeSpecs = make(map[string]*wire.Type) + for { // Find the first object to begin generation. length, object, err := state.ReadHeader(r) @@ -223,18 +243,19 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { // loop in decode.go. But we don't register type information, // etc. and just print the raw structures. var ( - oid uint64 = 1 - tid uint64 = 1 + tid uint64 = 1 + objects []wire.Object ) - for oid <= length { + for oid := uint64(1); oid <= length; { // Unmarshal the object. encoded := wire.Load(r) // Is this a type? - if _, ok := encoded.(*wire.Type); ok { - str, _ := format(graph, 0, encoded, html) + if typ, ok := encoded.(*wire.Type); ok { + str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dt%d", graph, tid) - if html { + p.typeSpecs[tag] = typ + if p.html { // See below. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) } @@ -245,17 +266,24 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { continue } + // Otherwise, it is a node. + objects = append(objects, encoded) + oid++ + } + + for i, encoded := range objects { + // oid starts at 1. + oid := i + 1 // Format the node. - str, _ := format(graph, 0, encoded, html) + str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dr%d", graph, oid) - if html { + if p.html { // Create a little tag with an anchor next to it for linking. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) } if _, err := fmt.Fprintf(w, "%s = %s\n", tag, str); err != nil { return err } - oid++ } } @@ -264,10 +292,10 @@ func printStream(w io.Writer, r wire.Reader, html bool) (err error) { // PrintText reads the stream from r and prints text to w. func PrintText(w io.Writer, r wire.Reader) error { - return printStream(w, r, false /* html */) + return (&printer{}).printStream(w, r) } // PrintHTML reads the stream from r and prints html to w. func PrintHTML(w io.Writer, r wire.Reader) error { - return printStream(w, r, true /* html */) + return (&printer{html: true}).printStream(w, r) } diff --git a/pkg/state/tests/load_test.go b/pkg/state/tests/load_test.go index 1e9794296..3c73ac391 100644 --- a/pkg/state/tests/load_test.go +++ b/pkg/state/tests/load_test.go @@ -20,6 +20,14 @@ import ( func TestLoadHooks(t *testing.T) { runTestCases(t, false, "load-hooks", []interface{}{ + // Root object being a struct. + afterLoadStruct{v: 1}, + valueLoadStruct{v: 1}, + genericContainer{v: &afterLoadStruct{v: 1}}, + genericContainer{v: &valueLoadStruct{v: 1}}, + sliceContainer{v: []interface{}{&afterLoadStruct{v: 1}}}, + sliceContainer{v: []interface{}{&valueLoadStruct{v: 1}}}, + // Root object being a pointer. &afterLoadStruct{v: 1}, &valueLoadStruct{v: 1}, &genericContainer{v: &afterLoadStruct{v: 1}}, diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 4d47207f7..68535c3b1 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -38,6 +38,7 @@ go_library( "race_unsafe.go", "rwmutex_unsafe.go", "seqcount.go", + "spin_unsafe.go", "sync.go", ], marshal = False, diff --git a/pkg/sync/memmove_unsafe.go b/pkg/sync/memmove_unsafe.go index 1d7780695..f5e630009 100644 --- a/pkg/sync/memmove_unsafe.go +++ b/pkg/sync/memmove_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sync/mutex_unsafe.go b/pkg/sync/mutex_unsafe.go index dc034d561..f4c2e9642 100644 --- a/pkg/sync/mutex_unsafe.go +++ b/pkg/sync/mutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.16 +// +build !go1.17 // When updating the build constraint (above), check that syncMutex matches the // standard library sync.Mutex definition. diff --git a/pkg/sync/rwmutex_unsafe.go b/pkg/sync/rwmutex_unsafe.go index 995c0346e..b3b4dee78 100644 --- a/pkg/sync/rwmutex_unsafe.go +++ b/pkg/sync/rwmutex_unsafe.go @@ -4,7 +4,7 @@ // license that can be found in the LICENSE file. // +build go1.13 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/sync/seqatomic_unsafe.go b/pkg/sync/seqatomic_unsafe.go index eda6fb131..2184cb5ab 100644 --- a/pkg/sync/seqatomic_unsafe.go +++ b/pkg/sync/seqatomic_unsafe.go @@ -25,41 +25,35 @@ import ( type Value struct{} // SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race -// with any writer critical sections in sc. -func SeqAtomicLoad(sc *sync.SeqCount, ptr *Value) Value { - // This function doesn't use SeqAtomicTryLoad because doing so is - // measurably, significantly (~20%) slower; Go is awful at inlining. - var val Value +// with any writer critical sections in seq. +// +//go:nosplit +func SeqAtomicLoad(seq *sync.SeqCount, ptr *Value) Value { for { - epoch := sc.BeginRead() - if sync.RaceEnabled { - // runtime.RaceDisable() doesn't actually stop the race detector, - // so it can't help us here. Instead, call runtime.memmove - // directly, which is not instrumented by the race detector. - sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) - } else { - // This is ~40% faster for short reads than going through memmove. - val = *ptr - } - if sc.ReadOk(epoch) { - break + if val, ok := SeqAtomicTryLoad(seq, seq.BeginRead(), ptr); ok { + return val } } - return val } // SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section -// in sc initiated by a call to sc.BeginRead() that returned epoch. If the read -// would race with a writer critical section, SeqAtomicTryLoad returns +// in seq initiated by a call to seq.BeginRead() that returned epoch. If the +// read would race with a writer critical section, SeqAtomicTryLoad returns // (unspecified, false). -func SeqAtomicTryLoad(sc *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (Value, bool) { - var val Value +// +//go:nosplit +func SeqAtomicTryLoad(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *Value) (val Value, ok bool) { if sync.RaceEnabled { + // runtime.RaceDisable() doesn't actually stop the race detector, so it + // can't help us here. Instead, call runtime.memmove directly, which is + // not instrumented by the race detector. sync.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) } else { + // This is ~40% faster for short reads than going through memmove. val = *ptr } - return val, sc.ReadOk(epoch) + ok = seq.ReadOk(epoch) + return } func init() { diff --git a/pkg/sync/seqcount.go b/pkg/sync/seqcount.go index a1e895352..2c5d3df99 100644 --- a/pkg/sync/seqcount.go +++ b/pkg/sync/seqcount.go @@ -8,7 +8,6 @@ package sync import ( "fmt" "reflect" - "runtime" "sync/atomic" ) @@ -43,9 +42,7 @@ type SeqCount struct { } // SeqCountEpoch tracks writer critical sections in a SeqCount. -type SeqCountEpoch struct { - val uint32 -} +type SeqCountEpoch uint32 // We assume that: // @@ -83,12 +80,25 @@ type SeqCountEpoch struct { // using this pattern. Most users of SeqCount will need to use the // SeqAtomicLoad function template in seqatomic.go. func (s *SeqCount) BeginRead() SeqCountEpoch { - epoch := atomic.LoadUint32(&s.epoch) - for epoch&1 != 0 { - runtime.Gosched() - epoch = atomic.LoadUint32(&s.epoch) + if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 { + return SeqCountEpoch(epoch) + } + return s.beginReadSlow() +} + +func (s *SeqCount) beginReadSlow() SeqCountEpoch { + i := 0 + for { + if canSpin(i) { + i++ + doSpin() + } else { + goyield() + } + if epoch := atomic.LoadUint32(&s.epoch); epoch&1 == 0 { + return SeqCountEpoch(epoch) + } } - return SeqCountEpoch{epoch} } // ReadOk returns true if the reader critical section initiated by a previous @@ -99,7 +109,7 @@ func (s *SeqCount) BeginRead() SeqCountEpoch { // Reader critical sections do not need to be explicitly terminated; the last // call to ReadOk is implicitly the end of the reader critical section. func (s *SeqCount) ReadOk(epoch SeqCountEpoch) bool { - return atomic.LoadUint32(&s.epoch) == epoch.val + return atomic.LoadUint32(&s.epoch) == uint32(epoch) } // BeginWrite indicates the beginning of a writer critical section. diff --git a/pkg/sync/spin_unsafe.go b/pkg/sync/spin_unsafe.go new file mode 100644 index 000000000..cafb2d065 --- /dev/null +++ b/pkg/sync/spin_unsafe.go @@ -0,0 +1,24 @@ +// Copyright 2020 The gVisor Authors. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.13 +// +build !go1.17 + +// Check go:linkname function signatures when updating Go version. + +package sync + +import ( + _ "unsafe" // for go:linkname +) + +//go:linkname canSpin sync.runtime_canSpin +func canSpin(i int) bool + +//go:linkname doSpin sync.runtime_doSpin +func doSpin() + +//go:linkname goyield runtime.goyield +func goyield() diff --git a/pkg/syncevent/broadcaster.go b/pkg/syncevent/broadcaster.go index 4bff59e7d..dabf08895 100644 --- a/pkg/syncevent/broadcaster.go +++ b/pkg/syncevent/broadcaster.go @@ -111,7 +111,9 @@ func (b *Broadcaster) SubscribeEvents(r *Receiver, filter Set) SubscriptionID { return id } -// Preconditions: table must not be full. len(table) is a power of 2. +// Preconditions: +// * table must not be full. +// * len(table) is a power of 2. func broadcasterTableInsert(table []broadcasterSlot, id SubscriptionID, r *Receiver, filter Set) { entry := broadcasterSlot{ receiver: r, diff --git a/pkg/syncevent/source.go b/pkg/syncevent/source.go index ddffb171a..d3d0f34c5 100644 --- a/pkg/syncevent/source.go +++ b/pkg/syncevent/source.go @@ -19,9 +19,11 @@ type Source interface { // SubscribeEvents causes the Source to notify the given Receiver of the // given subset of events. // - // Preconditions: r != nil. The ReceiverCallback for r must not take locks - // that are ordered prior to the Source; for example, it cannot call any - // Source methods. + // Preconditions: + // * r != nil. + // * The ReceiverCallback for r must not take locks that are ordered + // prior to the Source; for example, it cannot call any Source + // methods. SubscribeEvents(r *Receiver, filter Set) SubscriptionID // UnsubscribeEvents causes the Source to stop notifying the Receiver diff --git a/pkg/syncevent/waiter_unsafe.go b/pkg/syncevent/waiter_unsafe.go index ad271e1a0..518f18479 100644 --- a/pkg/syncevent/waiter_unsafe.go +++ b/pkg/syncevent/waiter_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.11 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/syserror/syserror.go b/pkg/syserror/syserror.go index 798e07b01..f516c8e46 100644 --- a/pkg/syserror/syserror.go +++ b/pkg/syserror/syserror.go @@ -33,6 +33,7 @@ var ( EBADFD = error(syscall.EBADFD) EBUSY = error(syscall.EBUSY) ECHILD = error(syscall.ECHILD) + ECONNABORTED = error(syscall.ECONNABORTED) ECONNREFUSED = error(syscall.ECONNREFUSED) ECONNRESET = error(syscall.ECONNRESET) EDEADLK = error(syscall.EDEADLK) @@ -153,6 +154,73 @@ func ConvertIntr(err, intr error) error { return err } +// SyscallRestartErrno represents a ERESTART* errno defined in the Linux's kernel +// include/linux/errno.h. These errnos are never returned to userspace +// directly, but are used to communicate the expected behavior of an +// interrupted syscall from the syscall to signal handling. +type SyscallRestartErrno int + +// These numeric values are significant because ptrace syscall exit tracing can +// observe them. +// +// For all of the following errnos, if the syscall is not interrupted by a +// signal delivered to a user handler, the syscall is restarted. +const ( + // ERESTARTSYS is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler without SA_RESTART set, and restarted otherwise. + ERESTARTSYS = SyscallRestartErrno(512) + + // ERESTARTNOINTR is returned by an interrupted syscall to indicate that it + // should always be restarted. + ERESTARTNOINTR = SyscallRestartErrno(513) + + // ERESTARTNOHAND is returned by an interrupted syscall to indicate that it + // should be converted to EINTR if interrupted by a signal delivered to a + // user handler, and restarted otherwise. + ERESTARTNOHAND = SyscallRestartErrno(514) + + // ERESTART_RESTARTBLOCK is returned by an interrupted syscall to indicate + // that it should be restarted using a custom function. The interrupted + // syscall must register a custom restart function by calling + // Task.SetRestartSyscallFn. + ERESTART_RESTARTBLOCK = SyscallRestartErrno(516) +) + +// Error implements error.Error. +func (e SyscallRestartErrno) Error() string { + // Descriptions are borrowed from strace. + switch e { + case ERESTARTSYS: + return "to be restarted if SA_RESTART is set" + case ERESTARTNOINTR: + return "to be restarted" + case ERESTARTNOHAND: + return "to be restarted if no handler" + case ERESTART_RESTARTBLOCK: + return "interrupted by signal" + default: + return "(unknown interrupt error)" + } +} + +// SyscallRestartErrnoFromReturn returns the SyscallRestartErrno represented by +// rv, the value in a syscall return register. +func SyscallRestartErrnoFromReturn(rv uintptr) (SyscallRestartErrno, bool) { + switch int(rv) { + case -int(ERESTARTSYS): + return ERESTARTSYS, true + case -int(ERESTARTNOINTR): + return ERESTARTNOINTR, true + case -int(ERESTARTNOHAND): + return ERESTARTNOHAND, true + case -int(ERESTART_RESTARTBLOCK): + return ERESTART_RESTARTBLOCK, true + default: + return 0, false + } +} + func init() { AddErrorTranslation(ErrWouldBlock, syscall.EWOULDBLOCK) AddErrorTranslation(ErrInterrupted, syscall.EINTR) diff --git a/pkg/syserror/syserror_test.go b/pkg/syserror/syserror_test.go index 29719752e..7036467c4 100644 --- a/pkg/syserror/syserror_test.go +++ b/pkg/syserror/syserror_test.go @@ -24,27 +24,20 @@ import ( var globalError error -func returnErrnoAsError() error { - return syscall.EINVAL -} - -func returnError() error { - return syserror.EINVAL -} - -func BenchmarkReturnErrnoAsError(b *testing.B) { +func BenchmarkAssignErrno(b *testing.B) { for i := b.N; i > 0; i-- { - returnErrnoAsError() + globalError = syscall.EINVAL } } -func BenchmarkReturnError(b *testing.B) { +func BenchmarkAssignError(b *testing.B) { for i := b.N; i > 0; i-- { - returnError() + globalError = syserror.EINVAL } } func BenchmarkCompareErrno(b *testing.B) { + globalError = syscall.EAGAIN j := 0 for i := b.N; i > 0; i-- { if globalError == syscall.EINVAL { @@ -54,6 +47,7 @@ func BenchmarkCompareErrno(b *testing.B) { } func BenchmarkCompareError(b *testing.B) { + globalError = syserror.EAGAIN j := 0 for i := b.N; i > 0; i-- { if globalError == syserror.EINVAL { @@ -63,6 +57,7 @@ func BenchmarkCompareError(b *testing.B) { } func BenchmarkSwitchErrno(b *testing.B) { + globalError = syscall.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { @@ -77,6 +72,7 @@ func BenchmarkSwitchErrno(b *testing.B) { } func BenchmarkSwitchError(b *testing.B) { + globalError = syserror.EPERM j := 0 for i := b.N; i > 0; i-- { switch globalError { diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index d82ed5205..4f551cd92 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -245,7 +245,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn { // Accept implements net.Conn.Accept. func (l *TCPListener) Accept() (net.Conn, error) { - n, wq, err := l.ep.Accept() + n, wq, err := l.ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. @@ -254,7 +254,7 @@ func (l *TCPListener) Accept() (net.Conn, error) { defer l.wq.EventUnregister(&waitEntry) for { - n, wq, err = l.ep.Accept() + n, wq, err = l.ep.Accept(nil) if err != tcpip.ErrWouldBlock { break @@ -541,7 +541,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, case <-notifyCh: } - err = ep.GetSockOpt(tcpip.ErrorOption{}) + err = ep.LastError() } if err != nil { ep.Close() diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 3c552988a..c975ad9cf 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -104,7 +104,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er err = ep.Connect(addr) if err == tcpip.ErrConnectStarted { <-ch - err = ep.GetSockOpt(tcpip.ErrorOption{}) + err = ep.LastError() } if err != nil { return nil, err diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index 5e135c50d..c326fab54 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -15,7 +15,6 @@ go_test( name = "buffer_test", size = "small", srcs = [ - "prependable_test.go", "view_test.go", ], library = ":buffer", diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go index 57d1922ab..ba21f4eca 100644 --- a/pkg/tcpip/buffer/prependable.go +++ b/pkg/tcpip/buffer/prependable.go @@ -32,19 +32,13 @@ func NewPrependable(size int) Prependable { return Prependable{buf: NewView(size), usedIdx: size} } -// NewPrependableFromView creates a Prependable from a View and allocates -// additional space if needed. +// NewPrependableFromView creates an entirely-used Prependable from a View. // -// NewPrependableFromView takes ownership of v. Note that if the entire -// prependable is used, further attempts to call Prepend will note that -// size > p.usedIdx and return nil. -func NewPrependableFromView(v View, extraCap int) Prependable { - if extraCap == 0 { - return Prependable{buf: v, usedIdx: 0} - } - buf := make([]byte, extraCap, extraCap+len(v)) - buf = append(buf, v...) - return Prependable{buf: buf, usedIdx: extraCap} +// NewPrependableFromView takes ownership of v. Note that since the entire +// prependable is used, further attempts to call Prepend will note that size > +// p.usedIdx and return nil. +func NewPrependableFromView(v View) Prependable { + return Prependable{buf: v, usedIdx: 0} } // NewEmptyPrependableFromView creates a new prependable buffer from a View. diff --git a/pkg/tcpip/buffer/prependable_test.go b/pkg/tcpip/buffer/prependable_test.go deleted file mode 100644 index 435a94a61..000000000 --- a/pkg/tcpip/buffer/prependable_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package buffer - -import ( - "reflect" - "testing" -) - -func TestNewPrependableFromView(t *testing.T) { - tests := []struct { - comment string - view View - extraSize int - want Prependable - }{ - { - comment: "Reserve extra space", - view: View("abc"), - extraSize: 2, - want: Prependable{buf: View("\x00\x00abc"), usedIdx: 2}, - }, - { - comment: "Don't reserve extra space", - view: View("abc"), - extraSize: 0, - want: Prependable{buf: View("abc"), usedIdx: 0}, - }, - } - - for _, testCase := range tests { - t.Run(testCase.comment, func(t *testing.T) { - prep := NewPrependableFromView(testCase.view, testCase.extraSize) - if !reflect.DeepEqual(prep, testCase.want) { - t.Errorf("NewPrependableFromView(%#v, %d) = %#v; want %#v", testCase.view, testCase.extraSize, prep, testCase.want) - } - }) - } -} diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 680eafd16..e8816c3f4 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -88,6 +88,16 @@ const ( // units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. + // + // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 + // header size). But RFC 791 section 3.2 discusses the design of the IPv4 + // fragment "allows 2**13 = 8192 fragments of 8 octets each for a total of + // 65,536 octets. Note that this is consistent with the the datagram total + // length field (of course, the header is counted in the total length and not + // in the fragments)." + IPv4MaximumPayloadSize = 65536 + // MinIPFragmentPayloadSize is the minimum number of payload bytes that // the first fragment must carry when an IPv4 packet is fragmented. MinIPFragmentPayloadSize = 8 diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index ea3823898..0761a1807 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -74,6 +74,10 @@ const ( // IPv6AddressSize is the size, in bytes, of an IPv6 address. IPv6AddressSize = 16 + // IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per + // RFC 8200 Section 4.5. + IPv6MaximumPayloadSize = 65535 + // IPv6ProtocolNumber is IPv6's network protocol number. IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd diff --git a/pkg/tcpip/header/parse/BUILD b/pkg/tcpip/header/parse/BUILD new file mode 100644 index 000000000..2adee9288 --- /dev/null +++ b/pkg/tcpip/header/parse/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "parse", + srcs = ["parse.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/header/parse/parse.go b/pkg/tcpip/header/parse/parse.go new file mode 100644 index 000000000..522135557 --- /dev/null +++ b/pkg/tcpip/header/parse/parse.go @@ -0,0 +1,166 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package parse provides utilities to parse packets. +package parse + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// ARP populates pkt's network header with an ARP header found in +// pkt.Data. +// +// Returns true if the header was successfully parsed. +func ARP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.NetworkHeader().Consume(header.ARPSize) + if ok { + pkt.NetworkProtocolNumber = header.ARPProtocolNumber + } + return ok +} + +// IPv4 parses an IPv4 packet found in pkt.Data and populates pkt's network +// header with the IPv4 header. +// +// Returns true if the header was successfully parsed. +func IPv4(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) + if !ok { + return false + } + ipHdr := header.IPv4(hdr) + + // Header may have options, determine the true header length. + headerLen := int(ipHdr.HeaderLength()) + if headerLen < header.IPv4MinimumSize { + // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in + // order for the packet to be valid. Figure out if we want to reject this + // case. + headerLen = header.IPv4MinimumSize + } + hdr, ok = pkt.NetworkHeader().Consume(headerLen) + if !ok { + return false + } + ipHdr = header.IPv4(hdr) + + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber + pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) + return true +} + +// IPv6 parses an IPv6 packet found in pkt.Data and populates pkt's network +// header with the IPv6 header. +func IPv6(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, fragID uint32, fragOffset uint16, fragMore bool, ok bool) { + hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + if !ok { + return 0, 0, 0, false, false + } + ipHdr := header.IPv6(hdr) + + // dataClone consists of: + // - Any IPv6 header bytes after the first 40 (i.e. extensions). + // - The transport header, if present. + // - Any other payload data. + views := [8]buffer.View{} + dataClone := pkt.Data.Clone(views[:]) + dataClone.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) + + // Iterate over the IPv6 extensions to find their length. + var nextHdr tcpip.TransportProtocolNumber + var extensionsSize int + +traverseExtensions: + for { + extHdr, done, err := it.Next() + if err != nil { + break + } + + // If we exhaust the extension list, the entire packet is the IPv6 header + // and (possibly) extensions. + if done { + extensionsSize = dataClone.Size() + break + } + + switch extHdr := extHdr.(type) { + case header.IPv6FragmentExtHdr: + if fragID == 0 && fragOffset == 0 && !fragMore { + fragID = extHdr.ID() + fragOffset = extHdr.FragmentOffset() + fragMore = extHdr.More() + } + + case header.IPv6RawPayloadHeader: + // We've found the payload after any extensions. + extensionsSize = dataClone.Size() - extHdr.Buf.Size() + nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) + break traverseExtensions + + default: + // Any other extension is a no-op, keep looping until we find the payload. + } + } + + // Put the IPv6 header with extensions in pkt.NetworkHeader(). + hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) + if !ok { + panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) + } + ipHdr = header.IPv6(hdr) + pkt.Data.CapLength(int(ipHdr.PayloadLength())) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber + + return nextHdr, fragID, fragOffset, fragMore, true +} + +// UDP parses a UDP packet found in pkt.Data and populates pkt's transport +// header with the UDP header. +// +// Returns true if the header was successfully parsed. +func UDP(pkt *stack.PacketBuffer) bool { + _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) + return ok +} + +// TCP parses a TCP packet found in pkt.Data and populates pkt's transport +// header with the TCP header. +// +// Returns true if the header was successfully parsed. +func TCP(pkt *stack.PacketBuffer) bool { + // TCP header is variable length, peek at it first. + hdrLen := header.TCPMinimumSize + hdr, ok := pkt.Data.PullUp(hdrLen) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of + // packets. + hdrLen = offset + } + + _, ok = pkt.TransportHeader().Consume(hdrLen) + return ok +} diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go index 9339d637f..98bdd29db 100644 --- a/pkg/tcpip/header/udp.go +++ b/pkg/tcpip/header/udp.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "math" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -55,6 +56,10 @@ const ( // UDPMinimumSize is the minimum size of a valid UDP packet. UDPMinimumSize = 8 + // UDPMaximumSize is the maximum size of a valid UDP packet. The length field + // in the UDP header is 16 bits as per RFC 768. + UDPMaximumSize = math.MaxUint16 + // UDPProtocolNumber is UDP's transport protocol number. UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..6c410c5a6 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,14 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + srcs = [ + "errors_test.go", + ], + library = "rawfile", + deps = [ + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index 99313ee25..5db4bf12b 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index a0a873c84..604868fd8 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error // *tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos. +// tcpip.ErrInvalidEndpointState (EINVAL). func TranslateErrno(e syscall.Errno) *tcpip.Error { - if err := translations[e]; err != nil { - return err + if e > 0 && e < syscall.Errno(len(translations)) { + if err := translations[e]; err != nil { + return err + } } return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go new file mode 100644 index 000000000..e4cdc66bd --- /dev/null +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestTranslateErrno(t *testing.T) { + for _, test := range []struct { + errno syscall.Errno + translated *tcpip.Error + }{ + { + errno: syscall.Errno(0), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(maxErrno), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(514), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.EEXIST, + translated: tcpip.ErrDuplicateAddress, + }, + } { + got := TranslateErrno(test.errno) + if got != test.translated { + t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + } + } +} diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 7cbc305e7..4aac12a8c 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 4fb127978..560477926 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -195,49 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var transProto uint8 src := tcpip.Address("unknown") dst := tcpip.Address("unknown") - id := 0 - size := uint16(0) + var size uint16 + var id uint32 var fragmentOffset uint16 var moreFragments bool - // Examine the packet using a new VV. Backing storage must not be written. - vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - + // Clone the packet buffer to not modify the original. + // + // We don't clone the original packet buffer so that the new packet buffer + // does not have any of its headers set. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return } - ipv4 := header.IPv4(hdr) + + ipv4 := header.IPv4(pkt.NetworkHeader().View()) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) - id = int(ipv4.ID()) + id = uint32(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) + proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return } - ipv6 := header.IPv6(hdr) + + ipv6 := header.IPv6(pkt.NetworkHeader().View()) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() - transProto = ipv6.NextHeader() + transProto = uint8(proto) size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + id = fragID + moreFragments = fragMore + fragmentOffset = fragOffset case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { + if parse.ARP(pkt) { return } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + + arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( "%s arp %s (%s) -> %s (%s) valid:%t", prefix, @@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { break } @@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) if !ok { break } @@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { + if ok := parse.UDP(pkt); !ok { break } - udp := header.UDP(hdr) + + udp := header.UDP(pkt.TransportHeader().View()) if fragmentOffset == 0 { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { + if ok := parse.TCP(pkt); !ok { break } - tcp := header.TCP(hdr) + + tcp := header.TCP(pkt.TransportHeader().View()) if fragmentOffset == 0 { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if offset > vv.Size() && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size()) + if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) break } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 6c137f693..0243424f6 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,18 +1,32 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "tun_endpoint_refs", + out = "tun_endpoint_refs.go", + package = "tun", + prefix = "tunEndpoint", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "tunEndpoint", + }, +) + go_library( name = "tun", srcs = [ "device.go", "protocol.go", + "tun_endpoint_refs.go", "tun_unsafe.go", ], visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/log", "//pkg/refs", "//pkg/sync", "//pkg/syserror", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 3b1510a33..b6ddbe81e 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -135,6 +134,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // 2. Creating a new NIC. id := tcpip.NICID(s.UniqueID()) + // TODO(gvisor.dev/1486): enable leak check for tunEndpoint. endpoint := &tunEndpoint{ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), stack: s, @@ -331,19 +331,18 @@ func (d *Device) WriteNotify() { // It is ref-counted as multiple opening files can attach to the same NIC. // The last owner is responsible for deleting the NIC. type tunEndpoint struct { + tunEndpointRefs *channel.Endpoint - refs.AtomicRefCount - stack *stack.Stack nicID tcpip.NICID name string isTap bool } -// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +// DecRef decrements refcount of e, removing NIC if it reaches 0. func (e *tunEndpoint) DecRef(ctx context.Context) { - e.DecRefWithDestructor(ctx, func(context.Context) { + e.tunEndpointRefs.DecRef(func() { e.stack.RemoveNIC(e.nicID) }) } diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index eddf7b725..b40dde96b 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -10,6 +10,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/stack", ], ) @@ -28,5 +29,6 @@ go_test( "//pkg/tcpip/network/ipv4", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 920872c3f..cb9225bd7 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -46,6 +47,7 @@ type endpoint struct { nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache + nud stack.NUDHandler } // DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint. @@ -78,7 +80,7 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, stack.NetworkHeaderPara // NetworkProtocolNumber implements stack.NetworkEndpoint.NetworkProtocolNumber. func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { - return e.protocol.Number() + return ProtocolNumber } // WritePackets implements stack.NetworkEndpoint.WritePackets. @@ -99,9 +101,25 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { switch h.Op() { case header.ARPRequest: localAddr := tcpip.Address(h.ProtocolAddressTarget()) - if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { - return // we have no useful answer, ignore the request + + if e.nud == nil { + if e.linkAddrCache.CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + } else { + if r.Stack().CheckLocalAddress(e.nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return // we have no useful answer, ignore the request + } + + remoteAddr := tcpip.Address(h.ProtocolAddressSender()) + remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.linkEP.MaxHeaderLength()) + header.ARPSize, }) @@ -113,11 +131,28 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { copy(packet.HardwareAddressTarget(), h.HardwareAddressSender()) copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) _ = e.linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) - fallthrough // also fill the cache from requests + case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + + if e.nud == nil { + e.linkAddrCache.AddLinkAddress(e.nicID, addr, linkAddr) + return + } + + // The solicited, override, and isRouter flags are not available for ARP; + // they are only available for IPv6 Neighbor Advertisements. + e.nud.HandleConfirmation(addr, linkAddr, stack.ReachabilityConfirmationFlags{ + // Solicited and unsolicited (also referred to as gratuitous) ARP Replies + // are handled equivalently to a solicited Neighbor Advertisement. + Solicited: true, + // If a different link address is received than the one cached, the entry + // should always go to Stale. + Override: false, + // ARP does not distinguish between router and non-router hosts. + IsRouter: false, + }) } } @@ -134,12 +169,13 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress } -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ protocol: p, nicID: nicID, linkEP: sender, linkAddrCache: linkAddrCache, + nud: nud, } } @@ -182,12 +218,12 @@ func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bo } // SetOption implements stack.NetworkProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.NetworkProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -199,11 +235,7 @@ func (*protocol) Wait() {} // Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - _, ok = pkt.NetworkHeader().Consume(header.ARPSize) - if !ok { - return 0, false, false - } - return 0, false, true + return 0, false, parse.ARP(pkt) } // NewProtocol returns an ARP network protocol. diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index c2c3e6891..9c9a859e3 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -16,10 +16,12 @@ package arp_test import ( "context" + "fmt" "strconv" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -32,57 +34,192 @@ import ( ) const ( - stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") - stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d") - stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") - stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") - stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") + nicID = 1 + + stackAddr = tcpip.Address("\x0a\x00\x00\x01") + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + + remoteAddr = tcpip.Address("\x0a\x00\x00\x02") + remoteLinkAddr = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06") + + unknownAddr = tcpip.Address("\x0a\x00\x00\x03") defaultChannelSize = 1 defaultMTU = 65536 + + // eventChanSize defines the size of event channels used by the neighbor + // cache's event dispatcher. The size chosen here needs to be sufficient to + // queue all the events received during tests before consumption. + // If eventChanSize is too small, the tests may deadlock. + eventChanSize = 32 +) + +type eventType uint8 + +const ( + entryAdded eventType = iota + entryChanged + entryRemoved ) +func (t eventType) String() string { + switch t { + case entryAdded: + return "add" + case entryChanged: + return "change" + case entryRemoved: + return "remove" + default: + return fmt.Sprintf("unknown (%d)", t) + } +} + +type eventInfo struct { + eventType eventType + nicID tcpip.NICID + addr tcpip.Address + linkAddr tcpip.LinkAddress + state stack.NeighborState +} + +func (e eventInfo) String() string { + return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state) +} + +// arpDispatcher implements NUDDispatcher to validate the dispatching of +// events upon certain NUD state machine events. +type arpDispatcher struct { + // C is where events are queued + C chan eventInfo +} + +var _ stack.NUDDispatcher = (*arpDispatcher)(nil) + +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { + e := eventInfo{ + eventType: entryAdded, + nicID: nicID, + addr: addr, + linkAddr: linkAddr, + state: state, + } + d.C <- e +} + +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { + e := eventInfo{ + eventType: entryChanged, + nicID: nicID, + addr: addr, + linkAddr: linkAddr, + state: state, + } + d.C <- e +} + +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { + e := eventInfo{ + eventType: entryRemoved, + nicID: nicID, + addr: addr, + linkAddr: linkAddr, + state: state, + } + d.C <- e +} + +func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { + select { + case got := <-d.C: + if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" { + return fmt.Errorf("got invalid event (-got +want):\n%s", diff) + } + case <-ctx.Done(): + return fmt.Errorf("%s for %s", ctx.Err(), want) + } + return nil +} + +func (d *arpDispatcher) waitForEventWithTimeout(want eventInfo, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.waitForEvent(ctx, want) +} + +func (d *arpDispatcher) nextEvent() (eventInfo, bool) { + select { + case event := <-d.C: + return event, true + default: + return eventInfo{}, false + } +} + type testContext struct { - t *testing.T - linkEP *channel.Endpoint - s *stack.Stack + s *stack.Stack + linkEP *channel.Endpoint + nudDisp *arpDispatcher } -func newTestContext(t *testing.T) *testContext { +func newTestContext(t *testing.T, useNeighborCache bool) *testContext { + c := stack.DefaultNUDConfigurations() + // Transition from Reachable to Stale almost immediately to test if receiving + // probes refreshes positive reachability. + c.BaseReachableTime = time.Microsecond + + d := arpDispatcher{ + // Create an event channel large enough so the neighbor cache doesn't block + // while dispatching events. Blocking could interfere with the timing of + // NUD transitions. + C: make(chan eventInfo, eventChanSize), + } + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()}, + NUDConfigs: c, + NUDDisp: &d, + UseNeighborCache: useNeighborCache, }) - ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) + ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + ep.LinkEPCapabilities |= stack.CapabilityResolutionRequired + wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNIC(nicID, wep); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil { t.Fatalf("AddAddress for ipv4 failed: %v", err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { - t.Fatalf("AddAddress for ipv4 failed: %v", err) + if !useNeighborCache { + // The remote address needs to be assigned to the NIC so we can receive and + // verify outgoing ARP packets. The neighbor cache isn't concerned with + // this; the tests that use linkAddrCache expect the ARP responses to be + // received by the same NIC. + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, remoteAddr); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } } - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { t.Fatalf("AddAddress for arp failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: header.IPv4EmptySubnet, - NIC: 1, + NIC: nicID, }}) return &testContext{ - t: t, - s: s, - linkEP: ep, + s: s, + linkEP: ep, + nudDisp: &d, } } @@ -91,7 +228,7 @@ func (c *testContext) cleanup() { } func TestDirectRequest(t *testing.T) { - c := newTestContext(t) + c := newTestContext(t, false /* useNeighborCache */) defer c.cleanup() const senderMAC = "\x01\x02\x03\x04\x05\x06" @@ -111,7 +248,7 @@ func TestDirectRequest(t *testing.T) { })) } - for i, address := range []tcpip.Address{stackAddr1, stackAddr2} { + for i, address := range []tcpip.Address{stackAddr, remoteAddr} { t.Run(strconv.Itoa(i), func(t *testing.T) { inject(address) pi, _ := c.linkEP.ReadContext(context.Background()) @@ -122,7 +259,7 @@ func TestDirectRequest(t *testing.T) { if !rep.IsValid() { t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) } - if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want { + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { t.Errorf("got HardwareAddressSender = %s, want = %s", got, want) } if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { @@ -137,7 +274,7 @@ func TestDirectRequest(t *testing.T) { }) } - inject(stackAddrBad) + inject(unknownAddr) // Sleep tests are gross, but this will only potentially flake // if there's a bug. If there is no bug this will reliably // succeed. @@ -148,6 +285,144 @@ func TestDirectRequest(t *testing.T) { } } +func TestDirectRequestWithNeighborCache(t *testing.T) { + c := newTestContext(t, true /* useNeighborCache */) + defer c.cleanup() + + tests := []struct { + name string + senderAddr tcpip.Address + senderLinkAddr tcpip.LinkAddress + targetAddr tcpip.Address + isValid bool + }{ + { + name: "Loopback", + senderAddr: stackAddr, + senderLinkAddr: stackLinkAddr, + targetAddr: stackAddr, + isValid: true, + }, + { + name: "Remote", + senderAddr: remoteAddr, + senderLinkAddr: remoteLinkAddr, + targetAddr: stackAddr, + isValid: true, + }, + { + name: "RemoteInvalidTarget", + senderAddr: remoteAddr, + senderLinkAddr: remoteLinkAddr, + targetAddr: unknownAddr, + isValid: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Inject an incoming ARP request. + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), test.senderLinkAddr) + copy(h.ProtocolAddressSender(), test.senderAddr) + copy(h.ProtocolAddressTarget(), test.targetAddr) + c.linkEP.InjectInbound(arp.ProtocolNumber, &stack.PacketBuffer{ + Data: v.ToVectorisedView(), + }) + + if !test.isValid { + // No packets should be sent after receiving an invalid ARP request. + // There is no need to perform a blocking read here, since packets are + // sent in the same function that handles ARP requests. + if pkt, ok := c.linkEP.Read(); ok { + t.Errorf("unexpected packet sent with network protocol number %d", pkt.Proto) + } + return + } + + // Verify an ARP response was sent. + pi, ok := c.linkEP.Read() + if !ok { + t.Fatal("expected ARP response to be sent, got none") + } + + if pi.Proto != arp.ProtocolNumber { + t.Fatalf("expected ARP response, got network protocol number %d", pi.Proto) + } + rep := header.ARP(pi.Pkt.NetworkHeader().View()) + if !rep.IsValid() { + t.Fatalf("invalid ARP response: len = %d; response = %x", len(rep), rep) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want { + t.Errorf("got HardwareAddressSender() = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want { + t.Errorf("got ProtocolAddressSender() = %s, want = %s", got, want) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress(h.HardwareAddressSender()); got != want { + t.Errorf("got HardwareAddressTarget() = %s, want = %s", got, want) + } + if got, want := tcpip.Address(rep.ProtocolAddressTarget()), tcpip.Address(h.ProtocolAddressSender()); got != want { + t.Errorf("got ProtocolAddressTarget() = %s, want = %s", got, want) + } + + // Verify the sender was saved in the neighbor cache. + wantEvent := eventInfo{ + eventType: entryAdded, + nicID: nicID, + addr: test.senderAddr, + linkAddr: tcpip.LinkAddress(test.senderLinkAddr), + state: stack.Stale, + } + if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { + t.Fatal(err) + } + + neighbors, err := c.s.Neighbors(nicID) + if err != nil { + t.Fatalf("c.s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("duplicate neighbor entry found (-existing +got):\n%s", diff) + } + t.Fatalf("exact neighbor entry duplicate found for addr=%s", n.Addr) + } + neighborByAddr[n.Addr] = n + } + + neigh, ok := neighborByAddr[test.senderAddr] + if !ok { + t.Fatalf("expected neighbor entry with Addr = %s", test.senderAddr) + } + if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { + t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) + } + if got, want := neigh.LocalAddr, stackAddr; got != want { + t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want) + } + if got, want := neigh.State, stack.Stale; got != want { + t.Errorf("got neighbor State = %s, want = %s", got, want) + } + + // No more events should be dispatched + for { + event, ok := c.nudDisp.nextEvent() + if !ok { + break + } + t.Errorf("unexpected %s", event) + } + }) + } +} + func TestLinkAddressRequest(t *testing.T) { tests := []struct { name string @@ -156,8 +431,8 @@ func TestLinkAddressRequest(t *testing.T) { }{ { name: "Unicast", - remoteLinkAddr: stackLinkAddr2, - expectLinkAddr: stackLinkAddr2, + remoteLinkAddr: remoteLinkAddr, + expectLinkAddr: remoteLinkAddr, }, { name: "Multicast", @@ -173,9 +448,9 @@ func TestLinkAddressRequest(t *testing.T) { t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1) - if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err) + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err) } pkt, ok := linkEP.Read() diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD index d1c728ccf..96c5f42f8 100644 --- a/pkg/tcpip/network/fragmentation/BUILD +++ b/pkg/tcpip/network/fragmentation/BUILD @@ -41,5 +41,7 @@ go_test( "reassembler_test.go", ], library = ":fragmentation", - deps = ["//pkg/tcpip/buffer"], + deps = [ + "//pkg/tcpip/buffer", + ], ) diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index 1827666c5..6a4843f92 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -120,29 +120,36 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea } // Process processes an incoming fragment belonging to an ID and returns a -// complete packet when all the packets belonging to that ID have been received. +// complete packet and its protocol number when all the packets belonging to +// that ID have been received. // // [first, last] is the range of the fragment bytes. // // first must be a multiple of the block size f is configured with. The size // of the fragment data must be a multiple of the block size, unless there are // no fragments following this fragment (more set to false). -func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, error) { +// +// proto is the protocol number marked in the fragment being processed. It has +// to be given here outside of the FragmentID struct because IPv6 should not use +// the protocol to identify a fragment. +func (f *Fragmentation) Process( + id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) ( + buffer.VectorisedView, uint8, bool, error) { if first > last { - return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) + return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) } if first%f.blockSize != 0 { - return buffer.VectorisedView{}, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) + return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is not a multiple of block size=%d: %w", first, f.blockSize, ErrInvalidArgs) } fragmentSize := last - first + 1 if more && fragmentSize%f.blockSize != 0 { - return buffer.VectorisedView{}, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) + return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragment size=%d bytes is not a multiple of block size=%d on non-final fragment: %w", fragmentSize, f.blockSize, ErrInvalidArgs) } if l := vv.Size(); l < int(fragmentSize) { - return buffer.VectorisedView{}, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) + return buffer.VectorisedView{}, 0, false, fmt.Errorf("got fragment size=%d bytes less than the expected fragment size=%d bytes (first=%d last=%d): %w", l, fragmentSize, first, last, ErrInvalidArgs) } vv.CapLength(int(fragmentSize)) @@ -160,14 +167,14 @@ func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv } f.mu.Unlock() - res, done, consumed, err := r.process(first, last, more, vv) + res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) if err != nil { // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() f.release(r) f.mu.Unlock() - return buffer.VectorisedView{}, false, fmt.Errorf("fragmentation processing error: %v", err) + return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed @@ -186,7 +193,7 @@ func (f *Fragmentation) Process(id FragmentID, first, last uint16, more bool, vv } } f.mu.Unlock() - return res, done, nil + return res, firstFragmentProto, done, nil } func (f *Fragmentation) release(r *reassembler) { diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index 9eedd33c4..416604659 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -38,12 +38,14 @@ type processInput struct { first uint16 last uint16 more bool + proto uint8 vv buffer.VectorisedView } type processOutput struct { - vv buffer.VectorisedView - done bool + vv buffer.VectorisedView + proto uint8 + done bool } var processTestCases = []struct { @@ -63,6 +65,17 @@ var processTestCases = []struct { }, }, { + comment: "Next Header protocol mismatch", + in: []processInput{ + {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, proto: 6, vv: vv(2, "01")}, + {id: FragmentID{ID: 0}, first: 2, last: 3, more: false, proto: 17, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "01", "23"), proto: 6, done: true}, + }, + }, + { comment: "Two IDs", in: []processInput{ {id: FragmentID{ID: 0}, first: 0, last: 1, more: true, vv: vv(2, "01")}, @@ -83,18 +96,26 @@ func TestFragmentationProcess(t *testing.T) { for _, c := range processTestCases { t.Run(c.comment, func(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, DefaultReassembleTimeout) + firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, done, err := f.Process(in.id, in.first, in.last, in.more, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) if err != nil { - t.Fatalf("f.Process(%+v, %+d, %+d, %t, %+v) failed: %v", in.id, in.first, in.last, in.more, in.vv, err) + t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", + in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) } if !reflect.DeepEqual(vv, c.out[i].vv) { - t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) + t.Errorf("got Process(%+v, %d, %d, %t, %d, %X) = (%X, _, _, _), want = (%X, _, _, _)", + in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), vv.ToView(), c.out[i].vv.ToView()) } if done != c.out[i].done { - t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, _, %t, _), want = (_, _, %t, _)", + in.id, in.first, in.last, in.more, in.proto, done, c.out[i].done) } if c.out[i].done { + if firstFragmentProto != proto { + t.Errorf("got Process(%+v, %d, %d, %t, %d, _) = (_, %d, _, _), want = (_, %d, _, _)", + in.id, in.first, in.last, in.more, in.proto, proto, firstFragmentProto) + } if _, ok := f.reassemblers[in.id]; ok { t.Errorf("Process(%d) did not remove buffer from reassemblers", i) } @@ -113,14 +134,14 @@ func TestReassemblingTimeout(t *testing.T) { timeout := time.Millisecond f := NewFragmentation(minBlockSize, 1024, 512, timeout) // Send first fragment with id = 0, first = 0, last = 0, and more = true. - f.Process(FragmentID{}, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) // Sleep more than the timeout. time.Sleep(2 * timeout) // Send another fragment that completes a packet. // However, no packet should be reassembled because the fragment arrived after the timeout. - _, done, err := f.Process(FragmentID{}, 1, 1, false, vv(1, "1")) + _, _, done, err := f.Process(FragmentID{}, 1, 1, false, 0xFF, vv(1, "1")) if err != nil { - t.Fatalf("f.Process(0, 1, 1, false, vv(1, \"1\")) failed: %v", err) + t.Fatalf("f.Process(0, 1, 1, false, 0xFF, vv(1, \"1\")) failed: %v", err) } if done { t.Errorf("Fragmentation does not respect the reassembling timeout.") @@ -130,15 +151,15 @@ func TestReassemblingTimeout(t *testing.T) { func TestMemoryLimits(t *testing.T) { f := NewFragmentation(minBlockSize, 3, 1, DefaultReassembleTimeout) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1")) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2")) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3")) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -154,9 +175,9 @@ func TestMemoryLimits(t *testing.T) { func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { f := NewFragmentation(minBlockSize, 1, 0, DefaultReassembleTimeout) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) got := f.size want := 1 @@ -248,12 +269,12 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, DefaultReassembleTimeout) - _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, vv(len(test.data), test.data)) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) if !errors.Is(err, test.err) { - t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, _, %v), want = (_, _, %v)", test.first, test.last, test.more, test.data, err, test.err) + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } if done { - t.Errorf("got Proceess(_, %d, %d, %t, %q) = (_, true, _), want = (_, false, _)", test.first, test.last, test.more, test.data) + t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, true, _), want = (_, _, false, _)", test.first, test.last, test.more, test.data) } }) } diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 50d30bbf0..f044867dc 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -34,6 +34,7 @@ type reassembler struct { reassemblerEntry id FragmentID size int + proto uint8 mu sync.Mutex holes []hole deleted int @@ -46,7 +47,6 @@ func newReassembler(id FragmentID) *reassembler { r := &reassembler{ id: id, holes: make([]hole, 0, 16), - deleted: 0, heap: make(fragHeap, 0, 8), creationTime: time.Now(), } @@ -78,7 +78,7 @@ func (r *reassembler) updateHoles(first, last uint16, more bool) bool { return used } -func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int, error) { +func (r *reassembler) process(first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) (buffer.VectorisedView, uint8, bool, int, error) { r.mu.Lock() defer r.mu.Unlock() consumed := 0 @@ -86,7 +86,18 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise // A concurrent goroutine might have already reassembled // the packet and emptied the heap while this goroutine // was waiting on the mutex. We don't have to do anything in this case. - return buffer.VectorisedView{}, false, consumed, nil + return buffer.VectorisedView{}, 0, false, consumed, nil + } + // For IPv6, it is possible to have different Protocol values between + // fragments of a packet (because, unlike IPv4, the Protocol is not used to + // identify a fragment). In this case, only the Protocol of the first + // fragment must be used as per RFC 8200 Section 4.5. + // + // TODO(gvisor.dev/issue/3648): The entire first IP header should be recorded + // here (instead of just the protocol) because most IP options should be + // derived from the first fragment. + if first == 0 { + r.proto = proto } if r.updateHoles(first, last, more) { // We store the incoming packet only if it filled some holes. @@ -96,13 +107,13 @@ func (r *reassembler) process(first, last uint16, more bool, vv buffer.Vectorise } // Check if all the holes have been deleted and we are ready to reassamble. if r.deleted < len(r.holes) { - return buffer.VectorisedView{}, false, consumed, nil + return buffer.VectorisedView{}, 0, false, consumed, nil } res, err := r.heap.reassemble() if err != nil { - return buffer.VectorisedView{}, false, consumed, fmt.Errorf("fragment reassembly failed: %v", err) + return buffer.VectorisedView{}, 0, false, consumed, fmt.Errorf("fragment reassembly failed: %w", err) } - return res, true, consumed, nil + return res, r.proto, true, consumed, nil } func (r *reassembler) tooOld(timeout time.Duration) bool { diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 9007346fe..e45dd17f8 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -250,7 +250,7 @@ func buildDummyStack(t *testing.T) *stack.Stack { func TestIPv4Send(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, nil, &o, buildDummyStack(t)) + ep := proto.NewEndpoint(nicID, nil, nil, nil, &o, buildDummyStack(t)) defer ep.Close() // Allocate and initialize the payload view. @@ -287,7 +287,7 @@ func TestIPv4Send(t *testing.T) { func TestIPv4Receive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() totalLen := header.IPv4MinimumSize + 30 @@ -357,7 +357,7 @@ func TestIPv4ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize @@ -418,7 +418,7 @@ func TestIPv4ReceiveControl(t *testing.T) { func TestIPv4FragmentationReceive(t *testing.T) { o := testObject{t: t, v4: true} proto := ipv4.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() totalLen := header.IPv4MinimumSize + 24 @@ -495,7 +495,7 @@ func TestIPv4FragmentationReceive(t *testing.T) { func TestIPv6Send(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) + ep := proto.NewEndpoint(nicID, nil, nil, &o, channel.New(0, 1280, ""), buildDummyStack(t)) defer ep.Close() // Allocate and initialize the payload view. @@ -532,7 +532,7 @@ func TestIPv6Send(t *testing.T) { func TestIPv6Receive(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() totalLen := header.IPv6MinimumSize + 30 @@ -611,7 +611,7 @@ func TestIPv6ReceiveControl(t *testing.T) { t.Run(c.name, func(t *testing.T) { o := testObject{t: t} proto := ipv6.NewProtocol() - ep := proto.NewEndpoint(nicID, nil, &o, nil, buildDummyStack(t)) + ep := proto.NewEndpoint(nicID, nil, nil, &o, nil, buildDummyStack(t)) defer ep.Close() dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index d142b4ffa..f9c2aa980 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -13,6 +13,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/network/hash", "//pkg/tcpip/stack", @@ -30,6 +31,7 @@ go_test( "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/testutil", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 79872ec9a..b14b356d6 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/network/hash" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -59,7 +60,7 @@ type endpoint struct { } // NewEndpoint creates a new ipv4 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ nicID: nicID, linkEP: linkEP, @@ -235,14 +236,17 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw ipt := e.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() return nil } - // If the packet is manipulated as per NAT Ouput rules, handle packet - // based on destination address and do not send the packet to link layer. - // TODO(gvisor.dev/issue/170): We should do this for every packet, rather than - // only NATted packets, but removing this check short circuits broadcasts - // before they are sent out to other hosts. + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.stack.FindNetworkEndpoint(header.IPv4ProtocolNumber, netHeader.DestinationAddress()) @@ -297,8 +301,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) return n, err } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) - // Slow Path as we are dropping some packets in the batch degrade to + // Slow path as we are dropping some packets in the batch degrade to // emitting one packet at a time. n := 0 for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { @@ -318,12 +323,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe } if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err } n++ } r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, nil + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } // WriteHeaderIncludedPacket writes a packet already containing a network @@ -392,6 +400,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() return } @@ -404,29 +413,35 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { return } // The packet is a fragment, let's try to reassemble it. - last := h.FragmentOffset() + uint16(pkt.Data.Size()) - 1 - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < h.FragmentOffset() { + start := h.FragmentOffset() + // Drop the fragment if the size of the reassembled payload would exceed the + // maximum payload size. + // + // Note that this addition doesn't overflow even on 32bit architecture + // because pkt.Data.Size() should not exceed 65535 (the max IP datagram + // size). Otherwise the packet would've been rejected as invalid before + // reaching here. + if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return } var ready bool var err error - pkt.Data, ready, err = e.protocol.fragmentation.Process( + proto := h.Protocol() + pkt.Data, _, ready, err = e.protocol.fragmentation.Process( // As per RFC 791 section 2.3, the identification value is unique // for a source-destination pair and protocol. fragmentation.FragmentID{ Source: h.SourceAddress(), Destination: h.DestinationAddress(), ID: uint32(h.ID()), - Protocol: h.Protocol(), + Protocol: proto, }, - h.FragmentOffset(), - last, + start, + start+uint16(pkt.Data.Size())-1, h.More(), + proto, pkt.Data, ) if err != nil { @@ -484,10 +499,10 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case tcpip.DefaultTTLOption: - p.SetDefaultTTL(uint8(v)) + case *tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(*v)) return nil default: return tcpip.ErrUnknownProtocolOption @@ -495,7 +510,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) @@ -521,37 +536,14 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return 0, false, false } - ipHdr := header.IPv4(hdr) - - // Header may have options, determine the true header length. - headerLen := int(ipHdr.HeaderLength()) - if headerLen < header.IPv4MinimumSize { - // TODO(gvisor.dev/issue/2404): Per RFC 791, IHL needs to be at least 5 in - // order for the packet to be valid. Figure out if we want to reject this - // case. - headerLen = header.IPv4MinimumSize - } - hdr, ok = pkt.NetworkHeader().Consume(headerLen) - if !ok { - return 0, false, false - } - ipHdr = header.IPv4(hdr) - - // If this is a fragment, don't bother parsing the transport header. - parseTransportHeader := true - if ipHdr.More() || ipHdr.FragmentOffset() != 0 { - parseTransportHeader = false - } - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) - return ipHdr.TransportProtocol(), parseTransportHeader, true + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } // calculateMTU calculates the network-layer payload MTU based on the link-layer diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 197e3bc51..b14bc98e8 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -17,8 +17,6 @@ package ipv4_test import ( "bytes" "encoding/hex" - "fmt" - "math/rand" "testing" "github.com/google/go-cmp/cmp" @@ -28,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/testutil" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -92,31 +91,6 @@ func TestExcludeBroadcast(t *testing.T) { }) } -// makeRandPkt generates a randomize packet. hdrLength indicates how much -// data should already be in the header before WritePacket. extraLength -// indicates how much extra space should be in the header. The payload is made -// from many Views of the sizes listed in viewSizes. -func makeRandPkt(hdrLength int, extraLength int, viewSizes []int) *stack.PacketBuffer { - var views []buffer.View - totalLength := 0 - for _, s := range viewSizes { - newView := buffer.NewView(s) - rand.Read(newView) - views = append(views, newView) - totalLength += s - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: hdrLength + extraLength, - Data: buffer.NewVectorisedView(totalLength, views), - }) - pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - if _, err := rand.Read(pkt.TransportHeader().Push(hdrLength)); err != nil { - panic(fmt.Sprintf("rand.Read: %s", err)) - } - return pkt -} - // comparePayloads compared the contents of all the packets against the contents // of the source packet. func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketInfo *stack.PacketBuffer, mtu uint32) { @@ -186,63 +160,19 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI } } -type errorChannel struct { - *channel.Endpoint - Ch chan *stack.PacketBuffer - packetCollectorErrors []*tcpip.Error -} - -// newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket -// will return successive errors from packetCollectorErrors until the list is -// empty and then return nil each time. -func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) *errorChannel { - return &errorChannel{ - Endpoint: channel.New(size, mtu, linkAddr), - Ch: make(chan *stack.PacketBuffer, size), - packetCollectorErrors: packetCollectorErrors, - } -} - -// Drain removes all outbound packets from the channel and counts them. -func (e *errorChannel) Drain() int { - c := 0 - for { - select { - case <-e.Ch: - c++ - default: - return c - } - } -} - -// WritePacket stores outbound packets into the channel. -func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { - select { - case e.Ch <- pkt: - default: - } - - nextError := (*tcpip.Error)(nil) - if len(e.packetCollectorErrors) > 0 { - nextError = e.packetCollectorErrors[0] - e.packetCollectorErrors = e.packetCollectorErrors[1:] - } - return nextError -} - -type context struct { +type testRoute struct { stack.Route - linkEP *errorChannel + + linkEP *testutil.TestEndpoint } -func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { +func buildTestRoute(t *testing.T, ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) testRoute { // Make the packet and write it. s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, }) - ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) - s.CreateNIC(1, ep) + testEP := testutil.NewTestEndpoint(ep, packetCollectorErrors) + s.CreateNIC(1, testEP) const ( src = "\x10\x00\x00\x01" dst = "\x10\x00\x00\x02" @@ -262,9 +192,12 @@ func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32 if err != nil { t.Fatalf("s.FindRoute got %v, want %v", err, nil) } - return context{ + t.Cleanup(func() { + testEP.Close() + }) + return testRoute{ Route: r, - linkEP: ep, + linkEP: testEP, } } @@ -274,13 +207,13 @@ func TestFragmentation(t *testing.T) { manyPayloadViewsSizes[i] = 7 } fragTests := []struct { - description string - mtu uint32 - gso *stack.GSO - hdrLength int - extraLength int - payloadViewsSizes []int - expectedFrags int + description string + mtu uint32 + gso *stack.GSO + transportHeaderLength int + extraHeaderReserveLength int + payloadViewsSizes []int + expectedFrags int }{ {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, @@ -295,10 +228,10 @@ func TestFragmentation(t *testing.T) { for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - pkt := makeRandPkt(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) + r := buildTestRoute(t, channel.New(0, ft.mtu, ""), nil) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, ft.extraHeaderReserveLength, ft.payloadViewsSizes, header.IPv4ProtocolNumber) source := pkt.Clone() - c := buildContext(t, nil, ft.mtu) - err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ + err := r.WritePacket(ft.gso, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, @@ -307,24 +240,13 @@ func TestFragmentation(t *testing.T) { t.Errorf("err got %v, want %v", err, nil) } - var results []*stack.PacketBuffer - L: - for { - select { - case pi := <-c.linkEP.Ch: - results = append(results, pi) - default: - break L - } - } - - if got, want := len(results), ft.expectedFrags; got != want { - t.Errorf("len(result) got %d, want %d", got, want) + if got, want := len(r.linkEP.WrittenPackets), ft.expectedFrags; got != want { + t.Errorf("len(r.linkEP.WrittenPackets) got %d, want %d", got, want) } - if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { - t.Errorf("no errors yet len(result) got %d, want %d", got, want) + if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); got != want { + t.Errorf("no errors yet len(r.linkEP.WrittenPackets) got %d, want %d", got, want) } - compareFragments(t, results, source, ft.mtu) + compareFragments(t, r.linkEP.WrittenPackets, source, ft.mtu) }) } } @@ -335,21 +257,21 @@ func TestFragmentationErrors(t *testing.T) { fragTests := []struct { description string mtu uint32 - hdrLength int + transportHeaderLength int payloadViewsSizes []int packetCollectorErrors []*tcpip.Error }{ {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, - {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, + {"ErrorOnFirstFragMTUSmallerThanHeader", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { - pkt := makeRandPkt(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) - c := buildContext(t, ft.packetCollectorErrors, ft.mtu) - err := c.Route.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ + r := buildTestRoute(t, channel.New(0, ft.mtu, ""), ft.packetCollectorErrors) + pkt := testutil.MakeRandPkt(ft.transportHeaderLength, header.IPv4MinimumSize, ft.payloadViewsSizes, header.IPv4ProtocolNumber) + err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: 42, TOS: stack.DefaultTOS, @@ -364,7 +286,7 @@ func TestFragmentationErrors(t *testing.T) { if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { t.Errorf("err got %v, want %v", got, want) } - if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { + if got, want := len(r.linkEP.WrittenPackets), int(r.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { t.Errorf("after linkEP error len(result) got %d, want %d", got, want) } }) @@ -372,115 +294,308 @@ func TestFragmentationErrors(t *testing.T) { } func TestInvalidFragments(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 6 + ) + + payloadGen := func(payloadLen int) []byte { + payload := make([]byte, payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = 0x30 + } + return payload + } + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + autoChecksum bool // if true, the Checksum field will be overwritten. + } + // These packets have both IHL and TotalLength set to 0. - testCases := []struct { + tests := []struct { name string - packets [][]byte + fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 }{ { - "ihl_totallen_zero_valid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x7d, 0x30, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - }, - 1, - 0, - }, - { - "ihl_totallen_zero_invalid_frag_offset", - [][]byte{ - {0x40, 0x30, 0x00, 0x00, 0x6c, 0x74, 0x20, 0x00, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL and TotalLength zero, FragmentOffset non-zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagDontFragment | header.IPv4FlagMoreFragments, + FragmentOffset: 59776, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, }, - 1, - 0, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, }, { - // Total Length of 37(20 bytes IP header + 17 bytes of - // payload) - // Frag Offset of 0x1ffe = 8190*8 = 65520 - // Leading to the fragment end to be past 65535. - "ihl_totallen_valid_invalid_frag_offset_1", - [][]byte{ - {0x45, 0x30, 0x00, 0x25, 0x6c, 0x74, 0x1f, 0xfe, 0x30, 0x30, 0x30, 0x30, 0x39, 0x32, 0x39, 0x33, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL and TotalLength zero, FragmentOffset zero", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: 0, + TOS: tos, + TotalLength: 0, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(12), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 0, }, - // The following 3 tests were found by running a fuzzer and were - // triggering a panic in the IPv4 reassembler code. { - "ihl_less_than_ipv4_minimum_size_1", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0x0, 0xf3, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x1, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + // Payload 17 octets and Fragment offset 65520 + // Leading to the fragment end to be past 65536. + name: "fragment ends past 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 17, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(17), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, { - "ihl_less_than_ipv4_minimum_size_2", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x12, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + // Payload 16 octets and fragment offset 65520 + // Leading to the fragment end to be exactly 65536. + name: "fragment ends exactly at 65536", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: 0, + FragmentOffset: 65520, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(16), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 0, + wantMalformedFragments: 0, }, { - "ihl_less_than_ipv4_minimum_size_3", - [][]byte{ - {0x42, 0x30, 0x0, 0x30, 0x30, 0x40, 0xb3, 0x30, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x42, 0x30, 0x0, 0x8, 0x30, 0x40, 0x20, 0x0, 0x30, 0x6, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "IHL less than IPv4 minimum size", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 1944, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize - 12, + TOS: tos, + TotalLength: header.IPv4MinimumSize - 12, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 2, - 0, + wantMalformedIPPackets: 2, + wantMalformedFragments: 0, }, { - "fragment_with_short_total_len_extra_payload", - [][]byte{ - {0x46, 0x30, 0x00, 0x30, 0x30, 0x40, 0x0e, 0x12, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, - {0x46, 0x30, 0x00, 0x18, 0x30, 0x40, 0x20, 0x00, 0x30, 0x06, 0x30, 0x30, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30}, + name: "fragment with short TotalLength and extra payload", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 28, + ID: ident, + Flags: 0, + FragmentOffset: 28816, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize + 4, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 4, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(28), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, { - "multiple_fragments_with_more_fragments_set_to_false", - [][]byte{ - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x10, 0x00, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x00, 0x01, 0x61, 0x06, 0x34, 0x69, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - {0x45, 0x00, 0x00, 0x1c, 0x30, 0x40, 0x20, 0x00, 0x00, 0x06, 0x34, 0x1e, 0x73, 0x73, 0x69, 0x6e, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + name: "multiple fragments with More Fragments flag set to false", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 128, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: payloadGen(8), + autoChecksum: true, + }, }, - 1, - 1, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - const nicID tcpip.NICID = 42 + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ ipv4.NewProtocol(), }, }) + e := channel.New(0, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + len(f.payload) + hdr := buffer.NewPrependable(pktSize) - var linkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x30}) - var remoteLinkAddr = tcpip.LinkAddress([]byte{0x30, 0x30, 0x30, 0x30, 0x30, 0x31}) - ep := channel.New(10, 1500, linkAddr) - s.CreateNIC(nicID, sniffer.New(ep)) + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + copy(ip[header.IPv4MinimumSize:], f.payload) - for _, pkt := range tc.packets { - ep.InjectLinkAddr(header.IPv4ProtocolNumber, remoteLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(pkt), []buffer.View{pkt}), + if f.autoChecksum { + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + } + + vv := hdr.View().ToVectorisedView() + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, })) } - if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), tc.wantMalformedIPPackets; got != want { + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { t.Errorf("incorrect Stats.IP.MalformedPacketsReceived, got: %d, want: %d", got, want) } - if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), tc.wantMalformedFragments; got != want { + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("incorrect Stats.IP.MalformedFragmentsReceived, got: %d, want: %d", got, want) } }) @@ -534,6 +649,9 @@ func TestReceiveFragments(t *testing.T) { // the fragment block size of 8 (RFC 791 section 3.1 page 14). ipv4Payload3Addr1ToAddr2 := udpGen(127, 3, addr1, addr2) udpPayload3Addr1ToAddr2 := ipv4Payload3Addr1ToAddr2[header.UDPMinimumSize:] + // Used to test the max reassembled payload length (65,535 octets). + ipv4Payload4Addr1ToAddr2 := udpGen(header.UDPMaximumSize-header.UDPMinimumSize, 4, addr1, addr2) + udpPayload4Addr1ToAddr2 := ipv4Payload4Addr1ToAddr2[header.UDPMinimumSize:] type fragmentData struct { srcAddr tcpip.Address @@ -827,6 +945,28 @@ func TestReceiveFragments(t *testing.T) { }, expectedPayloads: nil, }, + { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: header.IPv4FlagMoreFragments, + fragmentOffset: 0, + payload: ipv4Payload4Addr1ToAddr2[:65512], + }, + { + srcAddr: addr1, + dstAddr: addr2, + id: 1, + flags: 0, + fragmentOffset: 65512, + payload: ipv4Payload4Addr1ToAddr2[65512:], + }, + }, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, + }, } for _, test := range tests { @@ -906,3 +1046,252 @@ func TestReceiveFragments(t *testing.T) { }) } } + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + linkEP func() stack.LinkEndpoint + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} }, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + // Parameterize the tests to run with both WritePacket and WritePackets. + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rt := buildRoute(t, nil, test.linkEP()) + + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + s.CreateNIC(1, linkEP) + const ( + src = "\x10\x00\x00\x01" + dst = "\x10\x00\x00\x02" + ) + s.AddAddress(1, ipv4.ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask(header.IPv4Broadcast)) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(0, src, dst, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return rt +} + +// limitedEP is a link endpoint that writes up to a certain number of packets +// before returning errors. +type limitedEP struct { + limit int +} + +// MTU implements LinkEndpoint.MTU. +func (*limitedEP) MTU() uint32 { + // Give an MTU that won't cause fragmentation for IPv4+UDP. + return header.IPv4MinimumSize + header.UDPMinimumSize +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*limitedEP) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if ep.limit == 0 { + return 0, tcpip.ErrInvalidEndpointState + } + nWritten := ep.limit + if nWritten > pkts.Len() { + nWritten = pkts.Len() + } + ep.limit -= nWritten + return nWritten, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*limitedEP) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*limitedEP) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index bcc64994e..cd5fe3ea8 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -13,6 +13,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/network/fragmentation", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 66d3a953a..7430b8fcd 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -15,8 +15,6 @@ package ipv6 import ( - "fmt" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -71,6 +69,59 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } +// getLinkAddrOption searches NDP options for a given link address option using +// the provided getAddr function as a filter. Returns the link address if +// found; otherwise, returns the zero link address value. Also returns true if +// the options are valid as per the wire format, false otherwise. +func getLinkAddrOption(it header.NDPOptionIterator, getAddr func(header.NDPOption) tcpip.LinkAddress) (tcpip.LinkAddress, bool) { + var linkAddr tcpip.LinkAddress + for { + opt, done, err := it.Next() + if err != nil { + return "", false + } + if done { + break + } + if addr := getAddr(opt); len(addr) != 0 { + // No RFCs define what to do when an NDP message has multiple Link-Layer + // Address options. Since no interface can have multiple link-layer + // addresses, we consider such messages invalid. + if len(linkAddr) != 0 { + return "", false + } + linkAddr = addr + } + } + return linkAddr, true +} + +// getSourceLinkAddr searches NDP options for the source link address option. +// Returns the link address if found; otherwise, returns the zero link address +// value. Also returns true if the options are valid as per the wire format, +// false otherwise. +func getSourceLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { + return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress { + if src, ok := opt.(header.NDPSourceLinkLayerAddressOption); ok { + return src.EthernetAddress() + } + return "" + }) +} + +// getTargetLinkAddr searches NDP options for the target link address option. +// Returns the link address if found; otherwise, returns the zero link address +// value. Also returns true if the options are valid as per the wire format, +// false otherwise. +func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { + return getLinkAddrOption(it, func(opt header.NDPOption) tcpip.LinkAddress { + if dst, ok := opt.(header.NDPTargetLinkLayerAddressOption); ok { + return dst.EthernetAddress() + } + return "" + }) +} + func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { stats := r.Stats().ICMP sent := stats.V6PacketsSent @@ -137,7 +188,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6NeighborSolicit: received.NeighborSolicit.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize || !isNDPValid() { + if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborSolicitMinimumSize { received.Invalid.Increment() return } @@ -147,14 +198,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // NDP messages cannot be fragmented. Also note that in the common case NDP // datagrams are very small and ToView() will not incur allocations. ns := header.NDPNeighborSolicit(payload.ToView()) - it, err := ns.Options().Iter(true) - if err != nil { - // If we have a malformed NDP NS option, drop the packet. + targetAddr := ns.TargetAddress() + + // As per RFC 4861 section 4.3, the Target Address MUST NOT be a multicast + // address. + if header.IsV6MulticastAddress(targetAddr) { received.Invalid.Increment() return } - targetAddr := ns.TargetAddress() s := r.Stack() if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { // We will only get an error if the NIC is unrecognized, which should not @@ -187,39 +239,22 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // so the packet is processed as defined in RFC 4861, as per RFC 4862 // section 5.4.3. - // Is the NS targetting us? - if e.linkAddrCache.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { + // Is the NS targeting us? + if s.CheckLocalAddress(e.nicID, ProtocolNumber, targetAddr) == 0 { return } - // If the NS message contains the Source Link-Layer Address option, update - // the link address cache with the value of the option. - // - // TODO(b/148429853): Properly process the NS message and do Neighbor - // Unreachability Detection. - var sourceLinkAddr tcpip.LinkAddress - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) - } - if done { - break - } + it, err := ns.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. + received.Invalid.Increment() + return + } - switch opt := opt.(type) { - case header.NDPSourceLinkLayerAddressOption: - // No RFCs define what to do when an NS message has multiple Source - // Link-Layer Address options. Since no interface can have multiple - // link-layer addresses, we consider such messages invalid. - if len(sourceLinkAddr) != 0 { - received.Invalid.Increment() - return - } - - sourceLinkAddr = opt.EthernetAddress() - } + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return } unspecifiedSource := r.RemoteAddress == header.IPv6Any @@ -237,6 +272,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } else if unspecifiedSource { received.Invalid.Increment() return + } else if e.nud != nil { + e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { e.linkAddrCache.AddLinkAddress(e.nicID, r.RemoteAddress, sourceLinkAddr) } @@ -304,7 +341,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6NeighborAdvert: received.NeighborAdvert.Increment() - if pkt.Data.Size() < header.ICMPv6NeighborAdvertSize || !isNDPValid() { + if !isNDPValid() || pkt.Data.Size() < header.ICMPv6NeighborAdvertSize { received.Invalid.Increment() return } @@ -314,17 +351,10 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // 5, NDP messages cannot be fragmented. Also note that in the common case // NDP datagrams are very small and ToView() will not incur allocations. na := header.NDPNeighborAdvert(payload.ToView()) - it, err := na.Options().Iter(true) - if err != nil { - // If we have a malformed NDP NA option, drop the packet. - received.Invalid.Increment() - return - } - targetAddr := na.TargetAddress() - stack := r.Stack() + s := r.Stack() - if isTentative, err := stack.IsAddrTentative(e.nicID, targetAddr); err != nil { + if isTentative, err := s.IsAddrTentative(e.nicID, targetAddr); err != nil { // We will only get an error if the NIC is unrecognized, which should not // happen. For now short-circuit this packet. // @@ -335,7 +365,14 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // DAD on, implying the address is not unique. In this case we let the // stack know so it can handle such a scenario and do nothing furthur with // the NDP NA. - stack.DupTentativeAddrDetected(e.nicID, targetAddr) + s.DupTentativeAddrDetected(e.nicID, targetAddr) + return + } + + it, err := na.Options().Iter(false /* check */) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.Invalid.Increment() return } @@ -348,39 +385,25 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // TODO(b/143147598): Handle the scenario described above. Also inform the // netstack integration that a duplicate address was detected outside of // DAD. + targetLinkAddr, ok := getTargetLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } // If the NA message has the target link layer option, update the link // address cache with the link address for the target of the message. - // - // TODO(b/148429853): Properly process the NA message and do Neighbor - // Unreachability Detection. - var targetLinkAddr tcpip.LinkAddress - for { - opt, done, err := it.Next() - if err != nil { - // This should never happen as Iter(true) above did not return an error. - panic(fmt.Sprintf("unexpected error when iterating over NDP options: %s", err)) - } - if done { - break + if len(targetLinkAddr) != 0 { + if e.nud == nil { + e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + return } - switch opt := opt.(type) { - case header.NDPTargetLinkLayerAddressOption: - // No RFCs define what to do when an NA message has multiple Target - // Link-Layer Address options. Since no interface can have multiple - // link-layer addresses, we consider such messages invalid. - if len(targetLinkAddr) != 0 { - received.Invalid.Increment() - return - } - - targetLinkAddr = opt.EthernetAddress() - } - } - - if len(targetLinkAddr) != 0 { - e.linkAddrCache.AddLinkAddress(e.nicID, targetAddr, targetLinkAddr) + e.nud.HandleConfirmation(targetAddr, targetLinkAddr, stack.ReachabilityConfirmationFlags{ + Solicited: na.SolicitedFlag(), + Override: na.OverrideFlag(), + IsRouter: na.RouterFlag(), + }) } case header.ICMPv6EchoRequest: @@ -440,27 +463,75 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme case header.ICMPv6RouterSolicit: received.RouterSolicit.Increment() - if !isNDPValid() { + + // + // Validate the RS as per RFC 4861 section 6.1.1. + // + + // Is the NDP payload of sufficient size to hold a Router Solictation? + if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRSMinimumSize { received.Invalid.Increment() return } - case header.ICMPv6RouterAdvert: - received.RouterAdvert.Increment() + stack := r.Stack() - // Is the NDP payload of sufficient size to hold a Router - // Advertisement? - if pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize || !isNDPValid() { + // Is the networking stack operating as a router? + if !stack.Forwarding(ProtocolNumber) { + // ... No, silently drop the packet. + received.RouterOnlyPacketsDroppedByHost.Increment() + return + } + + // Note that in the common case NDP datagrams are very small and ToView() + // will not incur allocations. + rs := header.NDPRouterSolicit(payload.ToView()) + it, err := rs.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. received.Invalid.Increment() return } - routerAddr := iph.SourceAddress() + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { + received.Invalid.Increment() + return + } + + // If the RS message has the source link layer option, update the link + // address cache with the link address for the source of the message. + if len(sourceLinkAddr) != 0 { + // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST + // NOT be included when the source IP address is the unspecified address. + // Otherwise, it SHOULD be included on link layers that have addresses. + if r.RemoteAddress == header.IPv6Any { + received.Invalid.Increment() + return + } + + if e.nud != nil { + // A RS with a specified source IP address modifies the NUD state + // machine in the same way a reachability probe would. + e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + } + } + + case header.ICMPv6RouterAdvert: + received.RouterAdvert.Increment() // // Validate the RA as per RFC 4861 section 6.1.2. // + // Is the NDP payload of sufficient size to hold a Router Advertisement? + if !isNDPValid() || pkt.Data.Size()-header.ICMPv6HeaderSize < header.NDPRAMinimumSize { + received.Invalid.Increment() + return + } + + routerAddr := iph.SourceAddress() + // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { // ...No, silently drop the packet. @@ -468,16 +539,18 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - // The remainder of payload must be only the router advertisement, so - // payload.ToView() always returns the advertisement. Per RFC 6980 section - // 5, NDP messages cannot be fragmented. Also note that in the common case - // NDP datagrams are very small and ToView() will not incur allocations. + // Note that in the common case NDP datagrams are very small and ToView() + // will not incur allocations. ra := header.NDPRouterAdvert(payload.ToView()) - opts := ra.Options() + it, err := ra.Options().Iter(false /* check */) + if err != nil { + // Options are not valid as per the wire format, silently drop the packet. + received.Invalid.Increment() + return + } - // Are options valid as per the wire format? - if _, err := opts.Iter(true); err != nil { - // ...No, silently drop the packet. + sourceLinkAddr, ok := getSourceLinkAddr(it) + if !ok { received.Invalid.Increment() return } @@ -487,12 +560,33 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // as RFC 4861 section 6.1.2 is concerned. // + // If the RA has the source link layer option, update the link address + // cache with the link address for the advertised router. + if len(sourceLinkAddr) != 0 && e.nud != nil { + e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + } + // Tell the NIC to handle the RA. stack := r.Stack() - rxNICID := r.NICID() - stack.HandleNDPRA(rxNICID, routerAddr, ra) + stack.HandleNDPRA(e.nicID, routerAddr, ra) case header.ICMPv6RedirectMsg: + // TODO(gvisor.dev/issue/2285): Call `e.nud.HandleProbe` after validating + // this redirect message, as per RFC 4871 section 7.3.3: + // + // "A Neighbor Cache entry enters the STALE state when created as a + // result of receiving packets other than solicited Neighbor + // Advertisements (i.e., Router Solicitations, Router Advertisements, + // Redirects, and Neighbor Solicitations). These packets contain the + // link-layer address of either the sender or, in the case of Redirect, + // the redirection target. However, receipt of these link-layer + // addresses does not confirm reachability of the forward-direction path + // to that node. Placing a newly created Neighbor Cache entry for which + // the link-layer address is known in the STALE state provides assurance + // that path failures are detected quickly. In addition, should a cached + // link-layer address be modified due to receiving one of the above + // messages, the state SHOULD also be set to STALE to provide prompt + // verification that the path to the new link-layer address is working." received.RedirectMsg.Increment() if !isNDPValid() { received.Invalid.Increment() diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 9e4eeea77..0f50bfb8e 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -31,6 +31,8 @@ import ( ) const ( + nicID = 1 + linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06") linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f") @@ -49,7 +51,10 @@ type stubLinkEndpoint struct { } func (*stubLinkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return 0 + // Indicate that resolution for link layer addresses is required to send + // packets over this link. This is needed so the NIC knows to allocate a + // neighbor table. + return stack.CapabilityResolutionRequired } func (*stubLinkEndpoint) MaxHeaderLength() uint16 { @@ -84,16 +89,184 @@ func (*stubLinkAddressCache) CheckLocalAddress(tcpip.NICID, tcpip.NetworkProtoco func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.LinkAddress) { } +type stubNUDHandler struct{} + +var _ stack.NUDHandler = (*stubNUDHandler)(nil) + +func (*stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +} + +func (*stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +} + +func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { +} + func TestICMPCounts(t *testing.T) { + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + UseNeighborCache: test.useNeighborCache, + }) + { + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + netProto := s.NetworkProtocolInstance(ProtocolNumber) + if netProto == nil { + t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) + } + ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) + defer ep.Close() + + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) + } + defer r.Release() + + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) + + types := []struct { + typ header.ICMPv6Type + size int + extraData []byte + }{ + { + typ: header.ICMPv6DstUnreachable, + size: header.ICMPv6DstUnreachableMinimumSize, + }, + { + typ: header.ICMPv6PacketTooBig, + size: header.ICMPv6PacketTooBigMinimumSize, + }, + { + typ: header.ICMPv6TimeExceeded, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6ParamProblem, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6EchoRequest, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6EchoReply, + size: header.ICMPv6EchoMinimumSize, + }, + { + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + }, + { + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + }, + { + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + }, + { + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + }, + { + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + }, + } + + handleIPv6Payload := func(icmp header.ICMPv6) { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize, + Data: buffer.View(icmp).ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + ep.HandlePacket(&r, pkt) + } + + for _, typ := range types { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + handleIPv6Payload(icmp) + } + + // Construct an empty ICMP packet so that + // Stats().ICMP.ICMPv6ReceivedPacketStats.Invalid is incremented. + handleIPv6Payload(header.ICMPv6(buffer.NewView(header.IPv6MinimumSize))) + + icmpv6Stats := s.Stats().ICMP.V6PacketsReceived + visitStats(reflect.ValueOf(&icmpv6Stats).Elem(), func(name string, s *tcpip.StatCounter) { + if got, want := s.Value(), uint64(1); got != want { + t.Errorf("got %s = %d, want = %d", name, got, want) + } + }) + if t.Failed() { + t.Logf("stats:\n%+v", s.Stats()) + } + }) + } +} + +func TestICMPCountsWithNeighborCache(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + UseNeighborCache: true, }) { - if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, &stubLinkEndpoint{}); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } } @@ -105,7 +278,7 @@ func TestICMPCounts(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } @@ -114,12 +287,12 @@ func TestICMPCounts(t *testing.T) { if netProto == nil { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) + ep := netProto.NewEndpoint(0, nil, &stubNUDHandler{}, &stubDispatcher{}, nil, s) defer ep.Close() - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + r, err := s.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) } defer r.Release() @@ -265,19 +438,19 @@ func newTestContext(t *testing.T) *testContext { if testing.Verbose() { wrappedEP0 = sniffer.New(wrappedEP0) } - if err := c.s0.CreateNIC(1, wrappedEP0); err != nil { + if err := c.s0.CreateNIC(nicID, wrappedEP0); err != nil { t.Fatalf("CreateNIC s0: %v", err) } - if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := c.s0.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress lladdr0: %v", err) } c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1) wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1}) - if err := c.s1.CreateNIC(1, wrappedEP1); err != nil { + if err := c.s1.CreateNIC(nicID, wrappedEP1); err != nil { t.Fatalf("CreateNIC failed: %v", err) } - if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil { + if err := c.s1.AddAddress(nicID, ProtocolNumber, lladdr1); err != nil { t.Fatalf("AddAddress lladdr1: %v", err) } @@ -288,7 +461,7 @@ func newTestContext(t *testing.T) *testContext { c.s0.SetRouteTable( []tcpip.Route{{ Destination: subnet0, - NIC: 1, + NIC: nicID, }}, ) subnet1, err := tcpip.NewSubnet(lladdr0, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr0)))) @@ -298,7 +471,7 @@ func newTestContext(t *testing.T) *testContext { c.s1.SetRouteTable( []tcpip.Route{{ Destination: subnet1, - NIC: 1, + NIC: nicID, }}, ) @@ -359,9 +532,9 @@ func TestLinkResolution(t *testing.T) { c := newTestContext(t) defer c.cleanup() - r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + r, err := c.s0.FindRoute(nicID, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + t.Fatalf("FindRoute(%d, %s, %s, _, false) = (_, %s), want = (_, nil)", nicID, lladdr0, lladdr1, err) } defer r.Release() @@ -376,14 +549,14 @@ func TestLinkResolution(t *testing.T) { var wq waiter.Queue ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq) if err != nil { - t.Fatalf("NewEndpoint(_) = _, %s, want = _, nil", err) + t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } for { - _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}) + _, resCh, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}) if resCh != nil { if err != tcpip.ErrNoLinkAddress { - t.Fatalf("ep.Write(_) = _, <non-nil>, %s, want = _, <non-nil>, tcpip.ErrNoLinkAddress", err) + t.Fatalf("ep.Write(_) = (_, <non-nil>, %s), want = (_, <non-nil>, tcpip.ErrNoLinkAddress)", err) } for _, args := range []routeArgs{ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, @@ -399,7 +572,7 @@ func TestLinkResolution(t *testing.T) { continue } if err != nil { - t.Fatalf("ep.Write(_) = _, _, %s", err) + t.Fatalf("ep.Write(_) = (_, _, %s)", err) } break } @@ -424,6 +597,7 @@ func TestICMPChecksumValidationSimple(t *testing.T) { size int extraData []byte statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool }{ { name: "DstUnreachable", @@ -480,6 +654,8 @@ func TestICMPChecksumValidationSimple(t *testing.T) { statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return stats.RouterSolicit }, + // Hosts MUST silently discard any received Router Solicitation messages. + routerOnly: true, }, { name: "RouterAdvert", @@ -516,84 +692,133 @@ func TestICMPChecksumValidationSimple(t *testing.T) { }, } - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr0) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } - - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) - } - { - subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) - if err != nil { - t.Fatal(err) - } - s.SetRouteTable( - []tcpip.Route{{ - Destination: subnet, - NIC: 1, - }}, - ) - } + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - handleIPv6Payload := func(checksum bool) { - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - if checksum { - icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" + } + t.Run(name, func(t *testing.T) { + e := channel.New(0, 1280, linkAddr0) + + // Indicate that resolution for link layer addresses is required to + // send packets over this link. This is needed so the NIC knows to + // allocate a neighbor table. + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + UseNeighborCache: test.useNeighborCache, + }) + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) + } + + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + } + { + subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable( + []tcpip.Route{{ + Destination: subnet, + NIC: nicID, + }}, + ) + } + + handleIPv6Payload := func(checksum bool) { + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + if checksum { + icmp.SetChecksum(header.ICMPv6Checksum(icmp, lladdr1, lladdr0, buffer.View{}.ToVectorisedView())) + } + ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(icmp)), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: header.NDPHopLimit, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), + }) + e.InjectInbound(ProtocolNumber, pkt) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + // Initial stat counts should be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // Without setting checksum, the incoming packet should + // be invalid. + handleIPv6Payload(false) + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + // Router only count should not have increased. + if got := routerOnly.Value(); got != 0 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + // Rx count of type typ.typ should not have increased. + if got := typStat.Value(); got != 0 { + t.Fatalf("got %s = %d, want = 0", typ.name, got) + } + + // When checksum is set, it should be received. + handleIPv6Payload(true) + if got := typStat.Value(); got != 1 { + t.Fatalf("got %s = %d, want = 1", typ.name, got) + } + // Invalid count should not have increased again. + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + if !isRouter && typ.routerOnly && test.useNeighborCache { + // Router only count should have increased. + if got := routerOnly.Value(); got != 1 { + t.Fatalf("got RouterOnlyPacketsReceivedByHost = %d, want = 1", got) + } + } + }) } - ip := header.IPv6(buffer.NewView(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(icmp)), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: lladdr1, - DstAddr: lladdr0, - }) - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buffer.NewVectorisedView(len(ip)+len(icmp), []buffer.View{buffer.View(ip), buffer.View(icmp)}), - }) - e.InjectInbound(ProtocolNumber, pkt) - } - - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) - - // Initial stat counts should be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // Without setting checksum, the incoming packet should - // be invalid. - handleIPv6Payload(false) - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } - // Rx count of type typ.typ should not have increased. - if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) - } - - // When checksum is set, it should be received. - handleIPv6Payload(true) - if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) - } - // Invalid count should not have increased again. - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) } }) } @@ -696,11 +921,11 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(_, _) = %s", err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) } { @@ -711,7 +936,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } @@ -750,7 +975,7 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Without setting checksum, the incoming packet should @@ -761,13 +986,13 @@ func TestICMPChecksumValidationWithPayload(t *testing.T) { } // Rx count of type typ.typ should not have increased. if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // When checksum is set, it should be received. handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Invalid count should not have increased again. if got := invalid.Value(); got != 1 { @@ -874,12 +1099,12 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, }) - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if err := s.AddAddress(1, ProtocolNumber, lladdr0); err != nil { - t.Fatalf("AddAddress(_, %d, %s) = %s", ProtocolNumber, lladdr0, err) + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) } { subnet, err := tcpip.NewSubnet(lladdr1, tcpip.AddressMask(strings.Repeat("\xff", len(lladdr1)))) @@ -889,7 +1114,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { s.SetRouteTable( []tcpip.Route{{ Destination: subnet, - NIC: 1, + NIC: nicID, }}, ) } @@ -929,7 +1154,7 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { t.Fatalf("got invalid = %d, want = 0", got) } if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Without setting checksum, the incoming packet should @@ -940,13 +1165,13 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } // Rx count of type typ.typ should not have increased. if got := typStat.Value(); got != 0 { - t.Fatalf("got %s = %d, want = 0", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // When checksum is set, it should be received. handleIPv6Payload(typ.typ, typ.size, typ.payloadSize, typ.payload, true) if got := typStat.Value(); got != 1 { - t.Fatalf("got %s = %d, want = 1", typ.name, got) + t.Fatalf("got = %d, want = 0", got) } // Invalid count should not have increased again. if got := invalid.Value(); got != 1 { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 0eafe9790..ee64d92d8 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/network/fragmentation" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -48,6 +49,7 @@ type endpoint struct { nicID tcpip.NICID linkEP stack.LinkEndpoint linkAddrCache stack.LinkAddressCache + nud stack.NUDHandler dispatcher stack.TransportDispatcher protocol *protocol stack *stack.Stack @@ -106,6 +108,32 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { e.addIPHeader(r, pkt, params) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesOutputDropped.Increment() + return nil + } + + // If the packet is manipulated as per NAT Output rules, handle packet + // based on destination address and do not send the packet to link + // layer. + // + // TODO(gvisor.dev/issue/170): We should do this for every + // packet, rather than only NATted packets, but removing this check + // short circuits broadcasts before they are sent out to other hosts. + if pkt.NatDone { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + ep.HandlePacket(&route, pkt) + return nil + } + } + if r.Loop&stack.PacketLoop != 0 { loopedR := r.MakeLoopedRoute() @@ -120,8 +148,11 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw return nil } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + return err + } r.Stats().IP.PacketsSent.Increment() - return e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt) + return nil } // WritePackets implements stack.LinkEndpoint.WritePackets. @@ -137,9 +168,50 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe e.addIPHeader(r, pb, params) } - n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + // iptables filtering. All packets that reach here are locally + // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) + ipt := e.stack.IPTables() + dropped, natPkts := ipt.CheckPackets(stack.Output, pkts, gso, r, nicName) + if len(dropped) == 0 && len(natPkts) == 0 { + // Fast path: If no packets are to be dropped then we can just invoke the + // faster WritePackets API directly. + n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + return n, err + } + r.Stats().IP.IPTablesOutputDropped.IncrementBy(uint64(len(dropped))) + + // Slow path as we are dropping some packets in the batch degrade to + // emitting one packet at a time. + n := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if _, ok := dropped[pkt]; ok { + continue + } + if _, ok := natPkts[pkt]; ok { + netHeader := header.IPv6(pkt.NetworkHeader().View()) + if ep, err := e.stack.FindNetworkEndpoint(header.IPv6ProtocolNumber, netHeader.DestinationAddress()); err == nil { + src := netHeader.SourceAddress() + dst := netHeader.DestinationAddress() + route := r.ReverseRoute(src, dst) + ep.HandlePacket(&route, pkt) + n++ + continue + } + } + if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) + // Dropped packets aren't errors, so include them in + // the return value. + return n + len(dropped), err + } + n++ + } + r.Stats().IP.PacketsSent.IncrementBy(uint64(n)) - return n, err + // Dropped packets aren't errors, so include them in the return value. + return n + len(dropped), nil } // WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet @@ -168,6 +240,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(h.NextHeader()), vv) hasFragmentHeader := false + // iptables filtering. All packets that reach here are intended for + // this machine and will not be forwarded. + ipt := e.stack.IPTables() + if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { + // iptables is telling us to drop the packet. + r.Stats().IP.IPTablesInputDropped.Increment() + return + } + for firstHeader := true; ; firstHeader = false { extHdr, done, err := it.Next() if err != nil { @@ -310,21 +391,18 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - last := start + uint16(fragmentPayloadLen) - 1 - // Drop the packet if the fragmentOffset is incorrect. i.e the - // combination of fragmentOffset and pkt.Data.size() causes a - // wrap around resulting in last being less than the offset. - if last < start { + // Drop the fragment if the size of the reassembled payload would exceed + // the maximum payload size. + if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { r.Stats().IP.MalformedPacketsReceived.Increment() r.Stats().IP.MalformedFragmentsReceived.Increment() return } - var ready bool // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. - pkt.Data, ready, err = e.protocol.fragmentation.Process( + data, proto, ready, err := e.protocol.fragmentation.Process( // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ @@ -333,8 +411,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ID: extHdr.ID(), }, start, - last, + start+uint16(fragmentPayloadLen)-1, extHdr.More(), + uint8(rawPayload.Identifier), rawPayload.Buf, ) if err != nil { @@ -342,12 +421,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { r.Stats().IP.MalformedFragmentsReceived.Increment() return } + pkt.Data = data if ready { // We create a new iterator with the reassembled packet because we could // have more extension headers in the reassembled payload, as per RFC - // 8200 section 4.5. - it = header.MakeIPv6PayloadIterator(rawPayload.Identifier, pkt.Data) + // 8200 section 4.5. We also use the NextHeader value from the first + // fragment. + it = header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(proto), pkt.Data) } case header.IPv6DestinationOptionsExtHdr: @@ -453,11 +534,12 @@ func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { } // NewEndpoint creates a new ipv6 endpoint. -func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { +func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, nud stack.NUDHandler, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint, st *stack.Stack) stack.NetworkEndpoint { return &endpoint{ nicID: nicID, linkEP: linkEP, linkAddrCache: linkAddrCache, + nud: nud, dispatcher: dispatcher, protocol: p, stack: st, @@ -465,10 +547,10 @@ func (p *protocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddres } // SetOption implements NetworkProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case tcpip.DefaultTTLOption: - p.SetDefaultTTL(uint8(v)) + case *tcpip.DefaultTTLOption: + p.SetDefaultTTL(uint8(*v)) return nil default: return tcpip.ErrUnknownProtocolOption @@ -476,7 +558,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements NetworkProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { case *tcpip.DefaultTTLOption: *v = tcpip.DefaultTTLOption(p.DefaultTTL()) @@ -502,75 +584,14 @@ func (*protocol) Close() {} // Wait implements stack.TransportProtocol.Wait. func (*protocol) Wait() {} -// Parse implements stack.TransportProtocol.Parse. +// Parse implements stack.NetworkProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool) { - hdr, ok := pkt.Data.PullUp(header.IPv6MinimumSize) + proto, _, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return 0, false, false } - ipHdr := header.IPv6(hdr) - - // dataClone consists of: - // - Any IPv6 header bytes after the first 40 (i.e. extensions). - // - The transport header, if present. - // - Any other payload data. - views := [8]buffer.View{} - dataClone := pkt.Data.Clone(views[:]) - dataClone.TrimFront(header.IPv6MinimumSize) - it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataClone) - - // Iterate over the IPv6 extensions to find their length. - // - // Parsing occurs again in HandlePacket because we don't track the - // extensions in PacketBuffer. Unfortunately, that means HandlePacket - // has to do the parsing work again. - var nextHdr tcpip.TransportProtocolNumber - foundNext := true - extensionsSize := 0 -traverseExtensions: - for extHdr, done, err := it.Next(); ; extHdr, done, err = it.Next() { - if err != nil { - break - } - // If we exhaust the extension list, the entire packet is the IPv6 header - // and (possibly) extensions. - if done { - extensionsSize = dataClone.Size() - foundNext = false - break - } - - switch extHdr := extHdr.(type) { - case header.IPv6FragmentExtHdr: - // If this is an atomic fragment, we don't have to treat it specially. - if !extHdr.More() && extHdr.FragmentOffset() == 0 { - continue - } - // This is a non-atomic fragment and has to be re-assembled before we can - // examine the payload for a transport header. - foundNext = false - - case header.IPv6RawPayloadHeader: - // We've found the payload after any extensions. - extensionsSize = dataClone.Size() - extHdr.Buf.Size() - nextHdr = tcpip.TransportProtocolNumber(extHdr.Identifier) - break traverseExtensions - - default: - // Any other extension is a no-op, keep looping until we find the payload. - } - } - - // Put the IPv6 header with extensions in pkt.NetworkHeader(). - hdr, ok = pkt.NetworkHeader().Consume(header.IPv6MinimumSize + extensionsSize) - if !ok { - panic(fmt.Sprintf("pkt.Data should have at least %d bytes, but only has %d.", header.IPv6MinimumSize+extensionsSize, pkt.Data.Size())) - } - ipHdr = header.IPv6(hdr) - pkt.Data.CapLength(int(ipHdr.PayloadLength())) - pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber - return nextHdr, foundNext, true + return proto, !fragMore && fragOffset == 0, true } // calculateMTU calculates the network-layer payload MTU based on the link-layer diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 0a183bfde..9eea1de8d 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "math" "testing" "github.com/google/go-cmp/cmp" @@ -687,6 +688,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { // Used to test cases where the fragment blocks are not a multiple of // the fragment block size of 8 (RFC 8200 section 4.5). udpPayload3Length = 127 + udpPayload4Length = header.IPv6MaximumPayloadSize - header.UDPMinimumSize fragmentExtHdrLen = 8 // Note, not all routing extension headers will be 8 bytes but this test // uses 8 byte routing extension headers for most sub tests. @@ -731,6 +733,10 @@ func TestReceiveIPv6Fragments(t *testing.T) { udpPayload3Addr1ToAddr2 := udpPayload3Addr1ToAddr2Buf[:] ipv6Payload3Addr1ToAddr2 := udpGen(udpPayload3Addr1ToAddr2, 3, addr1, addr2) + var udpPayload4Addr1ToAddr2Buf [udpPayload4Length]byte + udpPayload4Addr1ToAddr2 := udpPayload4Addr1ToAddr2Buf[:] + ipv6Payload4Addr1ToAddr2 := udpGen(udpPayload4Addr1ToAddr2, 4, addr1, addr2) + tests := []struct { name string expectedPayload []byte @@ -866,6 +872,46 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, }, { + name: "Two fragments with different Next Header values", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[:64], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload1Addr1ToAddr2)-64, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8, More = false, ID = 1 + // NextHeader value is different than the one in the first fragment, so + // this NextHeader should be ignored. + buffer.View([]byte{uint8(header.IPv6NoNextHeaderIdentifier), 0, 0, 64, 0, 0, 0, 1}), + + ipv6Payload1Addr1ToAddr2[64:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload1Addr1ToAddr2}, + }, + { name: "Two fragments with last fragment size not a multiple of fragment block size", fragments: []fragmentData{ { @@ -980,6 +1026,44 @@ func TestReceiveIPv6Fragments(t *testing.T) { expectedPayloads: nil, }, { + name: "Two fragments reassembled into a maximum UDP packet", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 0, More = true, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 0, 1, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[:65520], + }, + ), + }, + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+len(ipv6Payload4Addr1ToAddr2)-65520, + []buffer.View{ + // Fragment extension header. + // + // Fragment offset = 8190, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, 255, 240, 0, 0, 0, 1}), + + ipv6Payload4Addr1ToAddr2[65520:], + }, + ), + }, + }, + expectedPayloads: [][]byte{udpPayload4Addr1ToAddr2}, + }, + { name: "Two fragments with per-fragment routing header with zero segments left", fragments: []fragmentData{ { @@ -1532,3 +1616,343 @@ func TestReceiveIPv6Fragments(t *testing.T) { }) } } + +func TestInvalidIPv6Fragments(t *testing.T) { + const ( + nicID = 1 + fragmentExtHdrLen = 8 + ) + + payloadGen := func(payloadLen int) []byte { + payload := make([]byte, payloadLen) + for i := 0; i < len(payload); i++ { + payload[i] = 0x30 + } + return payload + } + + tests := []struct { + name string + fragments []fragmentData + wantMalformedIPPackets uint64 + wantMalformedFragments uint64 + }{ + { + name: "fragments reassembled into a payload exceeding the max IPv6 payload size", + fragments: []fragmentData{ + { + srcAddr: addr1, + dstAddr: addr2, + nextHdr: fragmentExtHdrID, + data: buffer.NewVectorisedView( + fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, + []buffer.View{ + // Fragment extension header. + // Fragment offset = 8190, More = false, ID = 1 + buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, + ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, + ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, + 0, 0, 0, 1}), + // Payload length = 16 + payloadGen(16), + }, + ), + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + NewProtocol(), + }, + }) + e := channel.New(0, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) + } + + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize) + + // Serialize IPv6 fixed header. + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(f.data.Size()), + NextHeader: f.nextHdr, + HopLimit: 255, + SrcAddr: f.srcAddr, + DstAddr: f.dstAddr, + }) + + vv := hdr.View().ToVectorisedView() + vv.Append(f.data) + + e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + })) + } + + if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { + t.Errorf("got Stats.IP.MalformedPacketsReceived = %d, want = %d", got, want) + } + if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { + t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) + } + }) + } +} + +func TestWriteStats(t *testing.T) { + const nPackets = 3 + tests := []struct { + name string + setup func(*testing.T, *stack.Stack) + linkEP func() stack.LinkEndpoint + expectSent int + expectDropped int + expectWritten int + }{ + { + name: "Accept all", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + expectSent: nPackets, + expectDropped: 0, + expectWritten: nPackets, + }, { + name: "Accept all with error", + // No setup needed, tables accept everything by default. + setup: func(*testing.T, *stack.Stack) {}, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets - 1} }, + expectSent: nPackets - 1, + expectDropped: 0, + expectWritten: nPackets - 1, + }, { + name: "Drop all", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + expectSent: 0, + expectDropped: nPackets, + expectWritten: nPackets, + }, { + name: "Drop some", + setup: func(t *testing.T, stk *stack.Stack) { + // Install Output DROP rule that matches only 1 + // of the 3 packets. + t.Helper() + ipt := stk.IPTables() + filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) + if !ok { + t.Fatalf("failed to find filter table") + } + // We'll match and DROP the last packet. + ruleIdx := filter.BuiltinChains[stack.Output] + filter.Rules[ruleIdx].Target = stack.DropTarget{} + filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} + // Make sure the next rule is ACCEPT. + filter.Rules[ruleIdx+1].Target = stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + t.Fatalf("failed to replace table: %v", err) + } + }, + linkEP: func() stack.LinkEndpoint { return &limitedEP{nPackets} }, + expectSent: nPackets - 1, + expectDropped: 1, + expectWritten: nPackets, + }, + } + + writers := []struct { + name string + writePackets func(*stack.Route, stack.PacketBufferList) (int, *tcpip.Error) + }{ + { + name: "WritePacket", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + nWritten := 0 + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + if err := rt.WritePacket(nil, stack.NetworkHeaderParams{}, pkt); err != nil { + return nWritten, err + } + nWritten++ + } + return nWritten, nil + }, + }, { + name: "WritePackets", + writePackets: func(rt *stack.Route, pkts stack.PacketBufferList) (int, *tcpip.Error) { + return rt.WritePackets(nil, pkts, stack.NetworkHeaderParams{}) + }, + }, + } + + for _, writer := range writers { + t.Run(writer.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + rt := buildRoute(t, nil, test.linkEP()) + + var pkts stack.PacketBufferList + for i := 0; i < nPackets; i++ { + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.UDPMinimumSize + int(rt.MaxHeaderLength()), + Data: buffer.NewView(0).ToVectorisedView(), + }) + pkt.TransportHeader().Push(header.UDPMinimumSize) + pkts.PushBack(pkt) + } + + test.setup(t, rt.Stack()) + + nWritten, _ := writer.writePackets(&rt, pkts) + + if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent { + t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent) + } + if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped { + t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped) + } + if nWritten != test.expectWritten { + t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten) + } + }) + } + }) + } +} + +func buildRoute(t *testing.T, packetCollectorErrors []*tcpip.Error, linkEP stack.LinkEndpoint) stack.Route { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + }) + s.CreateNIC(1, linkEP) + const ( + src = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + dst = "\xfc\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + ) + s.AddAddress(1, ProtocolNumber, src) + { + subnet, err := tcpip.NewSubnet(dst, tcpip.AddressMask("\xfc\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff")) + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: subnet, + NIC: 1, + }}) + } + rt, err := s.FindRoute(0, src, dst, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute got %v, want %v", err, nil) + } + return rt +} + +// limitedEP is a link endpoint that writes up to a certain number of packets +// before returning errors. +type limitedEP struct { + limit int +} + +// MTU implements LinkEndpoint.MTU. +func (*limitedEP) MTU() uint32 { + return header.IPv6MinimumMTU +} + +// Capabilities implements LinkEndpoint.Capabilities. +func (*limitedEP) Capabilities() stack.LinkEndpointCapabilities { return 0 } + +// MaxHeaderLength implements LinkEndpoint.MaxHeaderLength. +func (*limitedEP) MaxHeaderLength() uint16 { return 0 } + +// LinkAddress implements LinkEndpoint.LinkAddress. +func (*limitedEP) LinkAddress() tcpip.LinkAddress { return "" } + +// WritePacket implements LinkEndpoint.WritePacket. +func (ep *limitedEP) WritePacket(*stack.Route, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// WritePackets implements LinkEndpoint.WritePackets. +func (ep *limitedEP) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketBufferList, _ tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + if ep.limit == 0 { + return 0, tcpip.ErrInvalidEndpointState + } + nWritten := ep.limit + if nWritten > pkts.Len() { + nWritten = pkts.Len() + } + ep.limit -= nWritten + return nWritten, nil +} + +// WriteRawPacket implements LinkEndpoint.WriteRawPacket. +func (ep *limitedEP) WriteRawPacket(_ buffer.VectorisedView) *tcpip.Error { + if ep.limit == 0 { + return tcpip.ErrInvalidEndpointState + } + ep.limit-- + return nil +} + +// Attach implements LinkEndpoint.Attach. +func (*limitedEP) Attach(_ stack.NetworkDispatcher) {} + +// IsAttached implements LinkEndpoint.IsAttached. +func (*limitedEP) IsAttached() bool { return false } + +// Wait implements LinkEndpoint.Wait. +func (*limitedEP) Wait() {} + +// ARPHardwareType implements LinkEndpoint.ARPHardwareType. +func (*limitedEP) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareEther } + +// AddHeader implements LinkEndpoint.AddHeader. +func (*limitedEP) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} + +// limitedMatcher is an iptables matcher that matches after a certain number of +// packets are checked against it. +type limitedMatcher struct { + limit int +} + +// Name implements Matcher.Name. +func (*limitedMatcher) Name() string { + return "limitedMatcher" +} + +// Match implements Matcher.Match. +func (lm *limitedMatcher) Match(stack.Hook, *stack.PacketBuffer, string) (bool, bool) { + if lm.limit == 0 { + return true, false + } + lm.limit-- + return false, false +} diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index af71a7d6b..7434df4a1 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -18,6 +18,7 @@ import ( "strings" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -30,12 +31,13 @@ import ( // setupStackAndEndpoint creates a stack with a single NIC with a link-local // address llladdr and an IPv6 endpoint to a remote with link-local address // rlladdr -func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) { +func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address, useNeighborCache bool) (*stack.Stack, stack.NetworkEndpoint) { t.Helper() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()}, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil { @@ -63,8 +65,7 @@ func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } - ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubDispatcher{}, nil, s) - + ep := netProto.NewEndpoint(0, &stubLinkAddressCache{}, &stubNUDHandler{}, &stubDispatcher{}, nil, s) return s, ep } @@ -171,6 +172,123 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { } } +// TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache tests +// that receiving a valid NDP NS message with the Source Link Layer Address +// option results in a new entry in the link address cache for the sender of +// the message. +func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + optsBuf []byte + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "Valid", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6, 7}, + expectedLinkAddr: "\x02\x03\x04\x05\x06\x07", + }, + { + name: "Too Small", + optsBuf: []byte{1, 1, 2, 3, 4, 5, 6}, + }, + { + name: "Invalid Length", + optsBuf: []byte{1, 2, 2, 3, 4, 5, 6, 7}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + UseNeighborCache: true, + }) + e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr0) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + neighbors, err := s.Neighbors(nicID) + if err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + } + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + } + neighborByAddr[n.Addr] = n + } + + if neigh, ok := neighborByAddr[lladdr1]; len(test.expectedLinkAddr) != 0 { + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + if !ok { + t.Fatalf("expected a neighbor entry for %q", lladdr1) + } + if neigh.LinkAddr != test.expectedLinkAddr { + t.Errorf("got link address = %s, want = %s", neigh.LinkAddr, test.expectedLinkAddr) + } + if neigh.State != stack.Stale { + t.Errorf("got NUD state = %s, want = %s", neigh.State, stack.Stale) + } + } else { + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + + if ok { + t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + } + } + }) + } +} + func TestNeighorSolicitationResponse(t *testing.T) { const nicID = 1 nicAddr := lladdr0 @@ -180,6 +298,20 @@ func TestNeighorSolicitationResponse(t *testing.T) { remoteLinkAddr0 := linkAddr1 remoteLinkAddr1 := linkAddr2 + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + tests := []struct { name string nsOpts header.NDPOptionsSerializer @@ -338,86 +470,92 @@ func TestNeighorSolicitationResponse(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - e := channel.New(1, 1280, nicLinkAddr) - if err := s.CreateNIC(nicID, e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } - if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { - t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) - } + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + UseNeighborCache: stackTyp.useNeighborCache, + }) + e := channel.New(1, 1280, nicLinkAddr) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, nicAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) + } - ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() - hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) - pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) - pkt.SetType(header.ICMPv6NeighborSolicit) - ns := header.NDPNeighborSolicit(pkt.NDPPayload()) - ns.SetTargetAddress(nicAddr) - opts := ns.Options() - opts.Serialize(test.nsOpts) - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) - payloadLength := hdr.UsedLength() - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: 255, - SrcAddr: test.nsSrc, - DstAddr: test.nsDst, - }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) + pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) + pkt.SetType(header.ICMPv6NeighborSolicit) + ns := header.NDPNeighborSolicit(pkt.NDPPayload()) + ns.SetTargetAddress(nicAddr) + opts := ns.Options() + opts.Serialize(test.nsOpts) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.nsSrc, test.nsDst, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: test.nsSrc, + DstAddr: test.nsDst, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid - invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } + e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - e.InjectLinkAddr(ProtocolNumber, test.nsSrcLinkAddr, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) + if test.nsInvalid { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } - if test.nsInvalid { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } + if p, got := e.Read(); got { + t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) + } - if p, got := e.Read(); got { - t.Fatalf("unexpected response to an invalid NS = %+v", p.Pkt) - } + // If we expected the NS to be invalid, we have nothing else to check. + return + } - // If we expected the NS to be invalid, we have nothing else to check. - return - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } + p, got := e.Read() + if !got { + t.Fatal("expected an NDP NA response") + } - p, got := e.Read() - if !got { - t.Fatal("expected an NDP NA response") - } + if p.Route.RemoteLinkAddress != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(test.naSrc), + checker.DstAddr(test.naDst), + checker.TTL(header.NDPHopLimit), + checker.NDPNA( + checker.NDPNASolicitedFlag(test.naSolicited), + checker.NDPNATargetAddress(nicAddr), + checker.NDPNAOptions([]header.NDPOption{ + header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), + }), + )) + }) } - - checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), - checker.SrcAddr(test.naSrc), - checker.DstAddr(test.naDst), - checker.TTL(header.NDPHopLimit), - checker.NDPNA( - checker.NDPNASolicitedFlag(test.naSolicited), - checker.NDPNATargetAddress(nicAddr), - checker.NDPNAOptions([]header.NDPOption{ - header.NDPTargetLinkLayerAddressOption(nicLinkAddr[:]), - }), - )) }) } } @@ -532,197 +670,380 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { } } -func TestNDPValidation(t *testing.T) { - setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { - t.Helper() - - // Create a stack with the assigned link-local address lladdr0 - // and an endpoint to lladdr1. - s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1) - - r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) - if err != nil { - t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) - } - - return s, ep, r - } - - handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { - nextHdr := uint8(header.ICMPv6ProtocolNumber) - var extensions buffer.View - if atomicFragment { - extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) - extensions[0] = nextHdr - nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) - } - - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), - Data: payload.ToVectorisedView(), - }) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(len(payload) + len(extensions)), - NextHeader: nextHdr, - HopLimit: hopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) - if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { - t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) - } - ep.HandlePacket(r, pkt) - } - - var tllData [header.NDPLinkLayerAddressSize]byte - header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ - header.NDPTargetLinkLayerAddressOption(linkAddr1), - }) +// TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache tests +// that receiving a valid NDP NA message with the Target Link Layer Address +// option does not result in a new entry in the neighbor cache for the target +// of the message. +func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *testing.T) { + const nicID = 1 - types := []struct { - name string - typ header.ICMPv6Type - size int - extraData []byte - statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + tests := []struct { + name string + optsBuf []byte + isValid bool }{ { - name: "RouterSolicit", - typ: header.ICMPv6RouterSolicit, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterSolicit - }, - }, - { - name: "RouterAdvert", - typ: header.ICMPv6RouterAdvert, - size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RouterAdvert - }, + name: "Valid", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6, 7}, + isValid: true, }, { - name: "NeighborSolicit", - typ: header.ICMPv6NeighborSolicit, - size: header.ICMPv6NeighborSolicitMinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborSolicit - }, + name: "Too Small", + optsBuf: []byte{2, 1, 2, 3, 4, 5, 6}, }, { - name: "NeighborAdvert", - typ: header.ICMPv6NeighborAdvert, - size: header.ICMPv6NeighborAdvertMinimumSize, - extraData: tllData[:], - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.NeighborAdvert - }, + name: "Invalid Length", + optsBuf: []byte{2, 2, 2, 3, 4, 5, 6, 7}, }, { - name: "RedirectMsg", - typ: header.ICMPv6RedirectMsg, - size: header.ICMPv6MinimumSize, - statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { - return stats.RedirectMsg + name: "Multiple", + optsBuf: []byte{ + 2, 1, 2, 3, 4, 5, 6, 7, + 2, 1, 2, 3, 4, 5, 6, 8, }, }, } - subTests := []struct { - name string - atomicFragment bool - hopLimit uint8 - code header.ICMPv6Code - valid bool + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + UseNeighborCache: true, + }) + e := channel.New(0, 1280, linkAddr0) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, lladdr0); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, lladdr0, err) + } + + ndpNASize := header.ICMPv6NeighborAdvertMinimumSize + len(test.optsBuf) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNASize) + pkt := header.ICMPv6(hdr.Prepend(ndpNASize)) + pkt.SetType(header.ICMPv6NeighborAdvert) + ns := header.NDPNeighborAdvert(pkt.NDPPayload()) + ns.SetTargetAddress(lladdr1) + opts := ns.Options() + copy(opts, test.optsBuf) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, lladdr1, lladdr0, buffer.VectorisedView{})) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 255, + SrcAddr: lladdr1, + DstAddr: lladdr0, + }) + + invalid := s.Stats().ICMP.V6PacketsReceived.Invalid + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + + e.InjectInbound(ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + + neighbors, err := s.Neighbors(nicID) + if err != nil { + t.Fatalf("s.Neighbors(%d): %s", nicID, err) + } + + neighborByAddr := make(map[tcpip.Address]stack.NeighborEntry) + for _, n := range neighbors { + if existing, ok := neighborByAddr[n.Addr]; ok { + if diff := cmp.Diff(existing, n); diff != "" { + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) + } + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + } + neighborByAddr[n.Addr] = n + } + + if neigh, ok := neighborByAddr[lladdr1]; ok { + t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + } + + if test.isValid { + // Invalid count should not have increased. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + } else { + // Invalid count should have increased. + if got := invalid.Value(); got != 1 { + t.Errorf("got invalid = %d, want = 1", got) + } + } + }) + } +} + +func TestNDPValidation(t *testing.T) { + stacks := []struct { + name string + useNeighborCache bool }{ { - name: "Valid", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 0, - valid: true, - }, - { - name: "Fragmented", - atomicFragment: true, - hopLimit: header.NDPHopLimit, - code: 0, - valid: false, - }, - { - name: "Invalid hop limit", - atomicFragment: false, - hopLimit: header.NDPHopLimit - 1, - code: 0, - valid: false, + name: "linkAddrCache", + useNeighborCache: false, }, { - name: "Invalid ICMPv6 code", - atomicFragment: false, - hopLimit: header.NDPHopLimit, - code: 1, - valid: false, + name: "neighborCache", + useNeighborCache: true, }, } - for _, typ := range types { - t.Run(typ.name, func(t *testing.T) { - for _, test := range subTests { - t.Run(test.name, func(t *testing.T) { - s, ep, r := setup(t) - defer r.Release() + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + setup := func(t *testing.T) (*stack.Stack, stack.NetworkEndpoint, stack.Route) { + t.Helper() - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - typStat := typ.statCounter(stats) + // Create a stack with the assigned link-local address lladdr0 + // and an endpoint to lladdr1. + s, ep := setupStackAndEndpoint(t, lladdr0, lladdr1, stackTyp.useNeighborCache) - icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) - copy(icmp[typ.size:], typ.extraData) - icmp.SetType(typ.typ) - icmp.SetCode(test.code) - icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + r, err := s.FindRoute(1, lladdr0, lladdr1, ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("FindRoute(_) = _, %s, want = _, nil", err) + } - // Rx count of the NDP message should initially be 0. - if got := typStat.Value(); got != 0 { - t.Errorf("got %s = %d, want = 0", typ.name, got) - } + return s, ep, r + } - // Invalid count should initially be 0. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } + handleIPv6Payload := func(payload buffer.View, hopLimit uint8, atomicFragment bool, ep stack.NetworkEndpoint, r *stack.Route) { + nextHdr := uint8(header.ICMPv6ProtocolNumber) + var extensions buffer.View + if atomicFragment { + extensions = buffer.NewView(header.IPv6FragmentExtHdrLength) + extensions[0] = nextHdr + nextHdr = uint8(header.IPv6FragmentExtHdrIdentifier) + } - if t.Failed() { - t.FailNow() - } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: header.IPv6MinimumSize + len(extensions), + Data: payload.ToVectorisedView(), + }) + ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize + len(extensions))) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(payload) + len(extensions)), + NextHeader: nextHdr, + HopLimit: hopLimit, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { + t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) + } + ep.HandlePacket(r, pkt) + } - handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + var tllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(tllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPTargetLinkLayerAddressOption(linkAddr1), + }) - // Rx count of the NDP packet should have increased. - if got := typStat.Value(); got != 1 { - t.Errorf("got %s = %d, want = 1", typ.name, got) - } + var sllData [header.NDPLinkLayerAddressSize]byte + header.NDPOptions(sllData[:]).Serialize(header.NDPOptionsSerializer{ + header.NDPSourceLinkLayerAddressOption(linkAddr1), + }) - want := uint64(0) - if !test.valid { - // Invalid count should have increased. - want = 1 - } - if got := invalid.Value(); got != want { - t.Errorf("got invalid = %d, want = %d", got, want) + types := []struct { + name string + typ header.ICMPv6Type + size int + extraData []byte + statCounter func(tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + routerOnly bool + }{ + { + name: "RouterSolicit", + typ: header.ICMPv6RouterSolicit, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterSolicit + }, + routerOnly: true, + }, + { + name: "RouterAdvert", + typ: header.ICMPv6RouterAdvert, + size: header.ICMPv6HeaderSize + header.NDPRAMinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RouterAdvert + }, + }, + { + name: "NeighborSolicit", + typ: header.ICMPv6NeighborSolicit, + size: header.ICMPv6NeighborSolicitMinimumSize, + extraData: sllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborSolicit + }, + }, + { + name: "NeighborAdvert", + typ: header.ICMPv6NeighborAdvert, + size: header.ICMPv6NeighborAdvertMinimumSize, + extraData: tllData[:], + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.NeighborAdvert + }, + }, + { + name: "RedirectMsg", + typ: header.ICMPv6RedirectMsg, + size: header.ICMPv6MinimumSize, + statCounter: func(stats tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { + return stats.RedirectMsg + }, + }, + } + + subTests := []struct { + name string + atomicFragment bool + hopLimit uint8 + code header.ICMPv6Code + valid bool + }{ + { + name: "Valid", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 0, + valid: true, + }, + { + name: "Fragmented", + atomicFragment: true, + hopLimit: header.NDPHopLimit, + code: 0, + valid: false, + }, + { + name: "Invalid hop limit", + atomicFragment: false, + hopLimit: header.NDPHopLimit - 1, + code: 0, + valid: false, + }, + { + name: "Invalid ICMPv6 code", + atomicFragment: false, + hopLimit: header.NDPHopLimit, + code: 1, + valid: false, + }, + } + + for _, typ := range types { + for _, isRouter := range []bool{false, true} { + name := typ.name + if isRouter { + name += " (Router)" } - }) + + t.Run(name, func(t *testing.T) { + for _, test := range subTests { + t.Run(test.name, func(t *testing.T) { + s, ep, r := setup(t) + defer r.Release() + + if isRouter { + // Enabling forwarding makes the stack act as a router. + s.SetForwarding(ProtocolNumber, true) + } + + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + routerOnly := stats.RouterOnlyPacketsDroppedByHost + typStat := typ.statCounter(stats) + + icmp := header.ICMPv6(buffer.NewView(typ.size + len(typ.extraData))) + copy(icmp[typ.size:], typ.extraData) + icmp.SetType(typ.typ) + icmp.SetCode(test.code) + icmp.SetChecksum(header.ICMPv6Checksum(icmp[:typ.size], r.LocalAddress, r.RemoteAddress, buffer.View(typ.extraData).ToVectorisedView())) + + // Rx count of the NDP message should initially be 0. + if got := typStat.Value(); got != 0 { + t.Errorf("got %s = %d, want = 0", typ.name, got) + } + + // Invalid count should initially be 0. + if got := invalid.Value(); got != 0 { + t.Errorf("got invalid = %d, want = 0", got) + } + + // RouterOnlyPacketsReceivedByHost count should initially be 0. + if got := routerOnly.Value(); got != 0 { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = 0", got) + } + + if t.Failed() { + t.FailNow() + } + + handleIPv6Payload(buffer.View(icmp), test.hopLimit, test.atomicFragment, ep, &r) + + // Rx count of the NDP packet should have increased. + if got := typStat.Value(); got != 1 { + t.Errorf("got %s = %d, want = 1", typ.name, got) + } + + want := uint64(0) + if !test.valid { + // Invalid count should have increased. + want = 1 + } + if got := invalid.Value(); got != want { + t.Errorf("got invalid = %d, want = %d", got, want) + } + + want = 0 + if test.valid && !isRouter && typ.routerOnly { + // RouterOnlyPacketsReceivedByHost count should have increased. + want = 1 + } + if got := routerOnly.Value(); got != want { + t.Errorf("got RouterOnlyPacketsReceivedByHost = %d, want = %d", got, want) + } + + }) + } + }) + } } }) } + } // TestRouterAdvertValidation tests that when the NIC is configured to handle // NDP Router Advertisement packets, it validates the Router Advertisement // properly before handling them. func TestRouterAdvertValidation(t *testing.T) { + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } + tests := []struct { name string src tcpip.Address @@ -844,61 +1165,67 @@ func TestRouterAdvertValidation(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - e := channel.New(10, 1280, linkAddr1) - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, - }) - - if err := s.CreateNIC(1, e); err != nil { - t.Fatalf("CreateNIC(_) = %s", err) - } + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := channel.New(10, 1280, linkAddr1) + e.LinkEPCapabilities |= stack.CapabilityResolutionRequired + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{NewProtocol()}, + UseNeighborCache: stackTyp.useNeighborCache, + }) + + if err := s.CreateNIC(1, e); err != nil { + t.Fatalf("CreateNIC(_) = %s", err) + } - icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) - hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) - pkt := header.ICMPv6(hdr.Prepend(icmpSize)) - pkt.SetType(header.ICMPv6RouterAdvert) - pkt.SetCode(test.code) - copy(pkt.NDPPayload(), test.ndpPayload) - payloadLength := hdr.UsedLength() - pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(payloadLength), - NextHeader: uint8(icmp.ProtocolNumber6), - HopLimit: test.hopLimit, - SrcAddr: test.src, - DstAddr: header.IPv6AllNodesMulticastAddress, - }) + icmpSize := header.ICMPv6HeaderSize + len(test.ndpPayload) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + icmpSize) + pkt := header.ICMPv6(hdr.Prepend(icmpSize)) + pkt.SetType(header.ICMPv6RouterAdvert) + pkt.SetCode(test.code) + copy(pkt.NDPPayload(), test.ndpPayload) + payloadLength := hdr.UsedLength() + pkt.SetChecksum(header.ICMPv6Checksum(pkt, test.src, header.IPv6AllNodesMulticastAddress, buffer.VectorisedView{})) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + NextHeader: uint8(icmp.ProtocolNumber6), + HopLimit: test.hopLimit, + SrcAddr: test.src, + DstAddr: header.IPv6AllNodesMulticastAddress, + }) - stats := s.Stats().ICMP.V6PacketsReceived - invalid := stats.Invalid - rxRA := stats.RouterAdvert + stats := s.Stats().ICMP.V6PacketsReceived + invalid := stats.Invalid + rxRA := stats.RouterAdvert - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - if got := rxRA.Value(); got != 0 { - t.Fatalf("got rxRA = %d, want = 0", got) - } + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + if got := rxRA.Value(); got != 0 { + t.Fatalf("got rxRA = %d, want = 0", got) + } - e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: hdr.View().ToVectorisedView(), - })) + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: hdr.View().ToVectorisedView(), + })) - if got := rxRA.Value(); got != 1 { - t.Fatalf("got rxRA = %d, want = 1", got) - } + if got := rxRA.Value(); got != 1 { + t.Fatalf("got rxRA = %d, want = 1", got) + } - if test.expectedSuccess { - if got := invalid.Value(); got != 0 { - t.Fatalf("got invalid = %d, want = 0", got) - } - } else { - if got := invalid.Value(); got != 1 { - t.Fatalf("got invalid = %d, want = 1", got) - } + if test.expectedSuccess { + if got := invalid.Value(); got != 0 { + t.Fatalf("got invalid = %d, want = 0", got) + } + } else { + if got := invalid.Value(); got != 1 { + t.Fatalf("got invalid = %d, want = 1", got) + } + } + }) } }) } diff --git a/pkg/tcpip/network/testutil/BUILD b/pkg/tcpip/network/testutil/BUILD new file mode 100644 index 000000000..e218563d0 --- /dev/null +++ b/pkg/tcpip/network/testutil/BUILD @@ -0,0 +1,17 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "testutil", + srcs = [ + "testutil.go", + ], + visibility = ["//pkg/tcpip/network/ipv4:__pkg__"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/network/testutil/testutil.go b/pkg/tcpip/network/testutil/testutil.go new file mode 100644 index 000000000..bf5ce74be --- /dev/null +++ b/pkg/tcpip/network/testutil/testutil.go @@ -0,0 +1,92 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil defines types and functions used to test Network Layer +// functionality such as IP fragmentation. +package testutil + +import ( + "fmt" + "math/rand" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// TestEndpoint is an endpoint used for testing, it stores packets written to it +// and can mock errors. +type TestEndpoint struct { + *channel.Endpoint + + // WrittenPackets is where we store packets written via WritePacket(). + WrittenPackets []*stack.PacketBuffer + + packetCollectorErrors []*tcpip.Error +} + +// NewTestEndpoint creates a new TestEndpoint endpoint. +// +// packetCollectorErrors can be used to set error values and each call to +// WritePacket will remove the first one from the slice and return it until +// the slice is empty - at that point it will return nil every time. +func NewTestEndpoint(ep *channel.Endpoint, packetCollectorErrors []*tcpip.Error) *TestEndpoint { + return &TestEndpoint{ + Endpoint: ep, + WrittenPackets: make([]*stack.PacketBuffer, 0), + packetCollectorErrors: packetCollectorErrors, + } +} + +// WritePacket stores outbound packets and may return an error if one was +// injected. +func (e *TestEndpoint) WritePacket(_ *stack.Route, _ *stack.GSO, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.WrittenPackets = append(e.WrittenPackets, pkt) + + if len(e.packetCollectorErrors) > 0 { + nextError := e.packetCollectorErrors[0] + e.packetCollectorErrors = e.packetCollectorErrors[1:] + return nextError + } + + return nil +} + +// MakeRandPkt generates a randomized packet. transportHeaderLength indicates +// how many random bytes will be copied in the Transport Header. +// extraHeaderReserveLength indicates how much extra space will be reserved for +// the other headers. The payload is made from Views of the sizes listed in +// viewSizes. +func MakeRandPkt(transportHeaderLength int, extraHeaderReserveLength int, viewSizes []int, proto tcpip.NetworkProtocolNumber) *stack.PacketBuffer { + var views buffer.VectorisedView + + for _, s := range viewSizes { + newView := buffer.NewView(s) + if _, err := rand.Read(newView); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + views.AppendView(newView) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: transportHeaderLength + extraHeaderReserveLength, + Data: views, + }) + pkt.NetworkProtocolNumber = proto + if _, err := rand.Read(pkt.TransportHeader().Push(transportHeaderLength)); err != nil { + panic(fmt.Sprintf("rand.Read: %s", err)) + } + return pkt +} diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index f6d592eb5..d87193650 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -400,7 +400,11 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) (reservedPort uint16, err *tcpip.Error) { +// +// An optional testPort closure can be passed in which if provided will be used +// to test if the picked port can be used. The function should return true if +// the port is safe to use, false otherwise. +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() @@ -412,12 +416,23 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { return 0, tcpip.ErrPortInUse } + if testPort != nil && !testPort(port) { + s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) + return 0, tcpip.ErrPortInUse + } return port, nil } // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst), nil + if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { + return false, nil + } + if testPort != nil && !testPort(p) { + s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst) + return false, nil + } + return true, nil }) } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 58db5868c..4bc949fd8 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -332,7 +332,7 @@ func TestPortReservation(t *testing.T) { pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) if err != test.want { t.Fatalf("ReservePort(.., .., %s, %d, %+v, %d, %v) = %v, want %v", test.ip, test.port, test.flags, test.device, test.dest, err, test.want) } diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 0ab089208..91fc26722 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -182,7 +182,7 @@ func main() { if terr == tcpip.ErrConnectStarted { fmt.Println("Connect is pending...") <-notifyCh - terr = ep.GetSockOpt(tcpip.ErrorOption{}) + terr = ep.LastError() } wq.EventUnregister(&waitEntry) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 9e37cab18..3f58a15ea 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -188,7 +188,7 @@ func main() { defer wq.EventUnregister(&waitEntry) for { - n, wq, err := ep.Accept() + n, wq, err := ep.Accept(nil) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 7dd344b4f..836682ea0 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -572,7 +572,9 @@ func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, tim // reapTupleLocked tries to remove tuple and its reply from the table. It // returns whether the tuple's connection has timed out. // -// Preconditions: ct.mu is locked for reading and bucket is locked. +// Preconditions: +// * ct.mu is locked for reading. +// * bucket is locked. func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bool { if !tuple.conn.timedOut(now) { return false diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go index 9dff23623..38c5bac71 100644 --- a/pkg/tcpip/stack/forwarder_test.go +++ b/pkg/tcpip/stack/forwarder_test.go @@ -16,6 +16,7 @@ package stack import ( "encoding/binary" + "math" "testing" "time" @@ -25,8 +26,9 @@ import ( ) const ( - fwdTestNetHeaderLen = 12 - fwdTestNetDefaultPrefixLen = 8 + fwdTestNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fwdTestNetHeaderLen = 12 + fwdTestNetDefaultPrefixLen = 8 // fwdTestNetDefaultMTU is the MTU, in bytes, used throughout the tests, // except where another value is explicitly used. It is chosen to match @@ -49,6 +51,8 @@ type fwdTestNetworkEndpoint struct { ep LinkEndpoint } +var _ NetworkEndpoint = (*fwdTestNetworkEndpoint)(nil) + func (f *fwdTestNetworkEndpoint) MTU() uint32 { return f.ep.MTU() - uint32(f.MaxHeaderLength()) } @@ -90,7 +94,7 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH b[srcAddrOffset] = r.LocalAddress[0] b[protocolNumberOffset] = byte(params.Protocol) - return f.ep.WritePacket(r, gso, fakeNetNumber, pkt) + return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt) } // WritePackets implements LinkEndpoint.WritePackets. @@ -108,15 +112,17 @@ func (*fwdTestNetworkEndpoint) Close() {} // resolution. type fwdTestNetworkProtocol struct { addrCache *linkAddrCache + neigh *neighborCache addrResolveDelay time.Duration - onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) + onLinkAddressResolved func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool) } +var _ NetworkProtocol = (*fwdTestNetworkProtocol)(nil) var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil) func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber { - return fakeNetNumber + return fwdTestNetNumber } func (f *fwdTestNetworkProtocol) MinimumPacketSize() int { @@ -139,7 +145,7 @@ func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocol return tcpip.TransportProtocolNumber(netHeader[protocolNumberOffset]), true, true } -func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { +func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) NetworkEndpoint { return &fwdTestNetworkEndpoint{ nicID: nicID, proto: f, @@ -148,22 +154,22 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache Li } } -func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error { +func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error { +func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } -func (f *fwdTestNetworkProtocol) Close() {} +func (*fwdTestNetworkProtocol) Close() {} -func (f *fwdTestNetworkProtocol) Wait() {} +func (*fwdTestNetworkProtocol) Wait() {} func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { - if f.addrCache != nil && f.onLinkAddressResolved != nil { + if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { - f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr) + f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) }) } return nil @@ -176,8 +182,8 @@ func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip return "", false } -func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { - return fakeNetNumber +func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return fwdTestNetNumber } // fwdTestPacketInfo holds all the information about an outbound packet. @@ -298,13 +304,16 @@ func (e *fwdTestLinkEndpoint) AddHeader(local, remote tcpip.LinkAddress, protoco panic("not implemented") } -func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *fwdTestLinkEndpoint) { +func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol, useNeighborCache bool) (ep1, ep2 *fwdTestLinkEndpoint) { // Create a stack with the network protocol and two NICs. s := New(Options{ NetworkProtocols: []NetworkProtocol{proto}, + UseNeighborCache: useNeighborCache, }) - proto.addrCache = s.linkAddrCache + if !useNeighborCache { + proto.addrCache = s.linkAddrCache + } // Enable forwarding. s.SetForwarding(proto.Number(), true) @@ -318,7 +327,7 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fwdTestNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } @@ -331,10 +340,19 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } - if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + if err := s.AddAddress(2, fwdTestNetNumber, "\x02"); err != nil { t.Fatal("AddAddress #2 failed:", err) } + if useNeighborCache { + // Control the neighbor cache for NIC 2. + nic, ok := s.nics[2] + if !ok { + t.Fatal("failed to get the neighbor cache for NIC 2") + } + proto.neigh = nic.neigh + } + // Route all packets to NIC 2. { subnet, err := tcpip.NewSubnet("\x00", "\x00") @@ -348,79 +366,129 @@ func fwdTestNetFactory(t *testing.T, proto *fwdTestNetworkProtocol) (ep1, ep2 *f } func TestForwardingWithStaticResolver(t *testing.T) { - // Create a network protocol with a static resolver. - proto := &fwdTestNetworkProtocol{ - onResolveStaticAddress: - // The network address 3 is resolved to the link address "c". - func(addr tcpip.Address) (tcpip.LinkAddress, bool) { - if addr == "\x03" { - return "c", true - } - return "", false + tests := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Create a network protocol with a static resolver. + proto := &fwdTestNetworkProtocol{ + onResolveStaticAddress: + // The network address 3 is resolved to the link address "c". + func(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\x03" { + return "c", true + } + return "", false + }, + } - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) + ep1, ep2 := fwdTestNetFactory(t, proto, test.useNeighborCache) - var p fwdTestPacketInfo + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - select { - case p = <-ep2.C: - default: - t.Fatal("packet not forwarded") - } + var p fwdTestPacketInfo - // Test that the static address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + select { + case p = <-ep2.C: + default: + t.Fatal("packet not forwarded") + } + + // Test that the static address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } func TestForwardingWithFakeResolver(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any address will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any address will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any address will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) - var p fwdTestPacketInfo + var p fwdTestPacketInfo - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } @@ -428,13 +496,15 @@ func TestForwardingWithNoResolver(t *testing.T) { // Create a network protocol without a resolver. proto := &fwdTestNetworkProtocol{} - ep1, ep2 := fwdTestNetFactory(t, proto) + // Whether or not we use the neighbor cache here does not matter since + // neither linkAddrCache nor neighborCache will be used. + ep1, ep2 := fwdTestNetFactory(t, proto, false /* useNeighborCache */) // inject an inbound packet to address 3 on NIC 1, and see if it is // forwarded to NIC 2. buf := buffer.NewView(30) buf[dstAddrOffset] = 3 - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ Data: buf.ToVectorisedView(), })) @@ -446,203 +516,334 @@ func TestForwardingWithNoResolver(t *testing.T) { } func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Only packets to address 3 will be resolved to the - // link address "c". - if addr == "\x03" { - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") - } + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + } + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Only packets to address 3 will be resolved to the + // link address "c". + if addr == "\x03" { + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + } + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject an inbound packet to address 4 on NIC 1. This packet should - // not be forwarded. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 4 - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - // Inject an inbound packet to address 3 on NIC 1, and see if it is - // forwarded to NIC 2. - buf = buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) + + // Inject an inbound packet to address 4 on NIC 1. This packet should + // not be forwarded. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 4 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + // Inject an inbound packet to address 3 on NIC 1, and see if it is + // forwarded to NIC 2. + buf = buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + }) } } func TestForwardingWithFakeResolverTwoPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - // Inject two inbound packets to address 3 on NIC 1. - for i := 0; i < 2; i++ { - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < 2; i++ { - var p fwdTestPacketInfo + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) - } + // Inject two inbound packets to address 3 on NIC 1. + for i := 0; i < 2; i++ { + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < 2; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] != 3 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } func TestForwardingWithFakeResolverManyPackets(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingPacketsPerResolution+5; i++ { - // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. - buf := buffer.NewView(30) - buf[dstAddrOffset] = 3 - // Set the packet sequence number. - binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingPacketsPerResolution; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } - - b := PayloadSince(p.Pkt.NetworkHeader()) - if b[dstAddrOffset] != 3 { - t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) - } - if len(b) < fwdTestNetHeaderLen+2 { - t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) - } - seqNumBuf := b[fwdTestNetHeaderLen:] + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - // The first 5 packets should not be forwarded so the sequence number should - // start with 5. - want := uint16(i + 5) - if n := binary.BigEndian.Uint16(seqNumBuf); n != want { - t.Fatalf("got the packet #%d, want = #%d", n, want) - } + for i := 0; i < maxPendingPacketsPerResolution+5; i++ { + // Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1. + buf := buffer.NewView(30) + buf[dstAddrOffset] = 3 + // Set the packet sequence number. + binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i)) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < maxPendingPacketsPerResolution; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + b := PayloadSince(p.Pkt.NetworkHeader()) + if b[dstAddrOffset] != 3 { + t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset]) + } + if len(b) < fwdTestNetHeaderLen+2 { + t.Fatalf("packet is too short to hold a sequence number: len(b) = %d", b) + } + seqNumBuf := b[fwdTestNetHeaderLen:] + + // The first 5 packets should not be forwarded so the sequence number should + // start with 5. + want := uint16(i + 5) + if n := binary.BigEndian.Uint16(seqNumBuf); n != want { + t.Fatalf("got the packet #%d, want = #%d", n, want) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } func TestForwardingWithFakeResolverManyResolutions(t *testing.T) { - // Create a network protocol with a fake resolver. - proto := &fwdTestNetworkProtocol{ - addrResolveDelay: 500 * time.Millisecond, - onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) { - // Any packets will be resolved to the link address "c". - cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + tests := []struct { + name string + useNeighborCache bool + proto *fwdTestNetworkProtocol + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, _ tcpip.LinkAddress) { + // Any packets will be resolved to the link address "c". + cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c") + }, + }, + }, + { + name: "neighborCache", + useNeighborCache: true, + proto: &fwdTestNetworkProtocol{ + addrResolveDelay: 500 * time.Millisecond, + onLinkAddressResolved: func(cache *linkAddrCache, neigh *neighborCache, addr tcpip.Address, remoteLinkAddr tcpip.LinkAddress) { + t.Helper() + if len(remoteLinkAddr) != 0 { + t.Fatalf("got remoteLinkAddr=%q, want unspecified", remoteLinkAddr) + } + // Any packets will be resolved to the link address "c". + neigh.HandleConfirmation(addr, "c", ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + }, + }, }, } - ep1, ep2 := fwdTestNetFactory(t, proto) - - for i := 0; i < maxPendingResolutions+5; i++ { - // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. - // Each packet has a different destination address (3 to - // maxPendingResolutions + 7). - buf := buffer.NewView(30) - buf[dstAddrOffset] = byte(3 + i) - ep1.InjectInbound(fakeNetNumber, NewPacketBuffer(PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - } - - for i := 0; i < maxPendingResolutions; i++ { - var p fwdTestPacketInfo - - select { - case p = <-ep2.C: - case <-time.After(time.Second): - t.Fatal("packet not forwarded") - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ep1, ep2 := fwdTestNetFactory(t, test.proto, test.useNeighborCache) - // The first 5 packets (address 3 to 7) should not be forwarded - // because their address resolutions are interrupted. - if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { - t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) - } + for i := 0; i < maxPendingResolutions+5; i++ { + // Inject inbound 'maxPendingResolutions + 5' packets on NIC 1. + // Each packet has a different destination address (3 to + // maxPendingResolutions + 7). + buf := buffer.NewView(30) + buf[dstAddrOffset] = byte(3 + i) + ep1.InjectInbound(fwdTestNetNumber, NewPacketBuffer(PacketBufferOptions{ + Data: buf.ToVectorisedView(), + })) + } - // Test that the address resolution happened correctly. - if p.RemoteLinkAddress != "c" { - t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) - } - if p.LocalLinkAddress != "b" { - t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) - } + for i := 0; i < maxPendingResolutions; i++ { + var p fwdTestPacketInfo + + select { + case p = <-ep2.C: + case <-time.After(time.Second): + t.Fatal("packet not forwarded") + } + + // The first 5 packets (address 3 to 7) should not be forwarded + // because their address resolutions are interrupted. + if nh := PayloadSince(p.Pkt.NetworkHeader()); nh[dstAddrOffset] < 8 { + t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", nh[dstAddrOffset]) + } + + // Test that the address resolution happened correctly. + if p.RemoteLinkAddress != "c" { + t.Fatalf("got p.RemoteLinkAddress = %s, want = c", p.RemoteLinkAddress) + } + if p.LocalLinkAddress != "b" { + t.Fatalf("got p.LocalLinkAddress = %s, want = b", p.LocalLinkAddress) + } + } + }) } } diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index c37da814f..4a521eca9 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -57,7 +57,72 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - tables: [numTables]Table{ + v4Tables: [numTables]Table{ + natID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: 1, + Forward: HookUnset, + Output: 2, + Postrouting: 3, + }, + }, + mangleID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: 0, + Output: 1, + }, + Underflows: [NumHooks]int{ + Prerouting: 0, + Input: HookUnset, + Forward: HookUnset, + Output: 1, + Postrouting: HookUnset, + }, + }, + filterID: Table{ + Rules: []Rule{ + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: AcceptTarget{}}, + Rule{Target: ErrorTarget{}}, + }, + BuiltinChains: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + Underflows: [NumHooks]int{ + Prerouting: HookUnset, + Input: 0, + Forward: 1, + Output: 2, + Postrouting: HookUnset, + }, + }, + }, + v6Tables: [numTables]Table{ natID: Table{ Rules: []Rule{ Rule{Target: AcceptTarget{}}, @@ -165,18 +230,21 @@ func EmptyNATTable() Table { } // GetTable returns a table by name. -func (it *IPTables) GetTable(name string) (Table, bool) { +func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { id, ok := nameToID[name] if !ok { return Table{}, false } it.mu.RLock() defer it.mu.RUnlock() - return it.tables[id], true + if ipv6 { + return it.v6Tables[id], true + } + return it.v4Tables[id], true } // ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { +func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { id, ok := nameToID[name] if !ok { return tcpip.ErrInvalidOptionValue @@ -190,7 +258,11 @@ func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { it.startReaper(reaperDelay) } it.modified = true - it.tables[id] = table + if ipv6 { + it.v6Tables[id] = table + } else { + it.v4Tables[id] = table + } return nil } @@ -213,8 +285,15 @@ const ( // should continue traversing the network stack and false when it should be // dropped. // +// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from +// which address and nicName can be gathered. Currently, address is only +// needed for prerouting and nicName is only needed for output. +// // Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, address tcpip.Address, nicName string) bool { +func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool { + if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { + return true + } // Many users never configure iptables. Spare them the cost of rule // traversal if rules have never been set. it.mu.RLock() @@ -235,9 +314,14 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr if tableID == natID && pkt.NatDone { continue } - table := it.tables[tableID] + var table Table + if pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber { + table = it.v6Tables[tableID] + } else { + table = it.v4Tables[tableID] + } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -248,7 +332,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, addr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, address); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr); v { case RuleAccept: continue case RuleDrop: @@ -315,8 +399,8 @@ func (it *IPTables) startReaper(interval time.Duration) { // should not go forward. // // Preconditions: -// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// - pkt.NetworkHeader is not nil. +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. @@ -341,13 +425,13 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r * } // Preconditions: -// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// - pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) chainVerdict { +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, address, nicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict { case RuleAccept: return chainAccept @@ -364,7 +448,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, address, nicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -388,13 +472,13 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId } // Preconditions: -// - pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// - pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, address tcpip.Address, nicName string) (RuleVerdict, int) { +// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. +// * pkt.NetworkHeader is not nil. +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. - if !rule.Filter.match(header.IPv4(pkt.NetworkHeader().View()), hook, nicName) { + if !rule.Filter.match(pkt, hook, nicName) { // Continue on to the next rule. return RuleJump, ruleIdx + 1 } @@ -413,11 +497,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, gso, r, address) + return rule.Target.Action(pkt, &it.connections, hook, gso, r, preroutingAddr) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. func (it *IPTables) OriginalDst(epID TransportEndpointID) (tcpip.Address, uint16, *tcpip.Error) { + it.mu.RLock() + defer it.mu.RUnlock() + if !it.modified { + return "", 0, tcpip.ErrNotConnected + } return it.connections.originalDst(epID) } diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 73274ada9..093ee6881 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -15,6 +15,7 @@ package stack import ( + "fmt" "strings" "sync" @@ -81,26 +82,25 @@ const ( // // +stateify savable type IPTables struct { - // mu protects tables, priorities, and modified. + // mu protects v4Tables, v6Tables, and modified. mu sync.RWMutex - - // tables maps tableIDs to tables. Holds builtin tables only, not user - // tables. mu must be locked for accessing. - tables [numTables]Table - - // priorities maps each hook to a list of table names. The order of the - // list is the order in which each table should be visited for that - // hook. mu needs to be locked for accessing. - priorities [NumHooks][]tableID - + // v4Tables and v6tables map tableIDs to tables. They hold builtin + // tables only, not user tables. mu must be locked for accessing. + v4Tables [numTables]Table + v6Tables [numTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. modified bool + // priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. It is immutable. + priorities [NumHooks][]tableID + connections ConnTrack - // reaperDone can be signalled to stop the reaper goroutine. + // reaperDone can be signaled to stop the reaper goroutine. reaperDone chan struct{} } @@ -148,13 +148,18 @@ type Rule struct { Target Target } -// IPHeaderFilter holds basic IP filtering data common to every rule. +// IPHeaderFilter performs basic IP header matching common to every rule. // // +stateify savable type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber + // CheckProtocol determines whether the Protocol field should be + // checked during matching. + // TODO(gvisor.dev/issue/3549): Check this field during matching. + CheckProtocol bool + // Dst matches the destination IP address. Dst tcpip.Address @@ -191,16 +196,43 @@ type IPHeaderFilter struct { OutputInterfaceInvert bool } -// match returns whether hdr matches the filter. -func (fl IPHeaderFilter) match(hdr header.IPv4, hook Hook, nicName string) bool { - // TODO(gvisor.dev/issue/170): Support other fields of the filter. +// match returns whether pkt matches the filter. +// +// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4 +// or IPv6 header length. +func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool { + // Extract header fields. + var ( + // TODO(gvisor.dev/issue/170): Support other filter fields. + transProto tcpip.TransportProtocolNumber + dstAddr tcpip.Address + srcAddr tcpip.Address + ) + switch proto := pkt.NetworkProtocolNumber; proto { + case header.IPv4ProtocolNumber: + hdr := header.IPv4(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + case header.IPv6ProtocolNumber: + hdr := header.IPv6(pkt.NetworkHeader().View()) + transProto = hdr.TransportProtocol() + dstAddr = hdr.DestinationAddress() + srcAddr = hdr.SourceAddress() + + default: + panic(fmt.Sprintf("unknown network protocol with EtherType: %d", proto)) + } + // Check the transport protocol. - if fl.Protocol != 0 && fl.Protocol != hdr.TransportProtocol() { + if fl.CheckProtocol && fl.Protocol != transProto { return false } - // Check the source and destination IPs. - if !filterAddress(hdr.DestinationAddress(), fl.DstMask, fl.Dst, fl.DstInvert) || !filterAddress(hdr.SourceAddress(), fl.SrcMask, fl.Src, fl.SrcInvert) { + // Check the addresses. + if !filterAddress(dstAddr, fl.DstMask, fl.Dst, fl.DstInvert) || + !filterAddress(srcAddr, fl.SrcMask, fl.Src, fl.SrcInvert) { return false } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index b15b8d1cb..33806340e 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -16,6 +16,7 @@ package stack import ( "fmt" + "math" "sync/atomic" "testing" "time" @@ -191,7 +192,13 @@ func TestCacheReplace(t *testing.T) { } func TestCacheResolution(t *testing.T) { - c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) + // There is a race condition causing this test to fail when the executor + // takes longer than the resolution timeout to call linkAddrCache.get. This + // is especially common when this test is run with gotsan. + // + // Using a large resolution timeout decreases the probability of experiencing + // this race condition and does not affect how long this test takes to run. + c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1) linkRes := &testLinkAddressResolver{cache: c} for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) @@ -275,3 +282,71 @@ func TestStaticResolution(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) } } + +// TestCacheWaker verifies that RemoveWaker removes a waker previously added +// through get(). +func TestCacheWaker(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + + // First, sanity check that wakers are working. + { + linkRes := &testLinkAddressResolver{cache: c} + s := sleep.Sleeper{} + defer s.Done() + + const wakerID = 1 + w := sleep.Waker{} + s.AddWaker(&w, wakerID) + + e := testAddrs[0] + + if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) + } + id, ok := s.Fetch(true /* block */) + if !ok { + t.Fatal("got s.Fetch(true) = (_, false), want = (_, true)") + } + if id != wakerID { + t.Fatalf("got s.Fetch(true) = (%d, %t), want = (%d, true)", id, ok, wakerID) + } + + if got, _, err := c.get(e.addr, linkRes, "", nil, nil); err != nil { + t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) + } else if got != e.linkAddr { + t.Fatalf("got c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) + } + } + + // Check that RemoveWaker works. + { + linkRes := &testLinkAddressResolver{cache: c} + s := sleep.Sleeper{} + defer s.Done() + + const wakerID = 2 // different than the ID used in the sanity check + w := sleep.Waker{} + s.AddWaker(&w, wakerID) + + e := testAddrs[1] + linkRes.onLinkAddressRequest = func() { + // Remove the waker before the linkAddrCache has the opportunity to send + // a notification. + c.removeWaker(e.addr, &w) + } + + if _, _, err := c.get(e.addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.get(%q, _, _, _, _) = %s, want = %s", e.addr.Addr, err, tcpip.ErrWouldBlock) + } + + if got, err := getBlocking(c, e.addr, linkRes); err != nil { + t.Fatalf("c.get(%q, _, _, _, _): %s", e.addr.Addr, err) + } else if got != e.linkAddr { + t.Fatalf("c.get(%q) = %q, want = %q", e.addr.Addr, got, e.linkAddr) + } + + if id, ok := s.Fetch(false /* block */); ok { + t.Fatalf("unexpected notification from waker with id %d", id) + } + } +} diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 1a6724c31..5e43a9b0b 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -2787,7 +2787,7 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { // stack.Stack will have a default route through the router (llAddr3) installed // and a static link-address (linkAddr3) added to the link address cache for the // router. -func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { +func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID, useNeighborCache bool) (*ndpDispatcher, *channel.Endpoint, *stack.Stack) { t.Helper() ndpDisp := &ndpDispatcher{ autoGenAddrC: make(chan ndpAutoGenAddrEvent, 1), @@ -2800,7 +2800,8 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd HandleRAs: true, AutoGenGlobalAddresses: true, }, - NDPDisp: ndpDisp, + NDPDisp: ndpDisp, + UseNeighborCache: useNeighborCache, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -2810,7 +2811,11 @@ func stackAndNdpDispatcherWithDefaultRoute(t *testing.T, nicID tcpip.NICID) (*nd Gateway: llAddr3, NIC: nicID, }}) - s.AddLinkAddress(nicID, llAddr3, linkAddr3) + if useNeighborCache { + s.AddStaticNeighbor(nicID, llAddr3, linkAddr3) + } else { + s.AddLinkAddress(nicID, llAddr3, linkAddr3) + } return ndpDisp, e, s } @@ -2884,110 +2889,128 @@ func addrForNewConnectionWithAddr(t *testing.T, s *stack.Stack, addr tcpip.FullA // TestAutoGenAddrDeprecateFromPI tests deprecating a SLAAC address when // receiving a PI with 0 preferred lifetime. func TestAutoGenAddrDeprecateFromPI(t *testing.T) { - const nicID = 1 + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + const nicID = 1 - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - default: - t.Fatal("expected addr auto gen event") - } - } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - expectPrimaryAddr(addr1) + // Receive PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + expectPrimaryAddr(addr1) - // Deprecate addr for prefix1 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) - expectAutoGenAddrEvent(addr1, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - // addr should still be the primary endpoint as there are no other addresses. - expectPrimaryAddr(addr1) + // Deprecate addr for prefix1 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 0)) + expectAutoGenAddrEvent(addr1, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + // addr should still be the primary endpoint as there are no other addresses. + expectPrimaryAddr(addr1) - // Refresh lifetimes of addr generated from prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetimes of addr generated from prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Deprecate addr for prefix2 immedaitely. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - expectAutoGenAddrEvent(addr2, deprecatedAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr1 should be the primary endpoint now since addr2 is deprecated but - // addr1 is not. - expectPrimaryAddr(addr1) - // addr2 is deprecated but if explicitly requested, it should be used. - fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Deprecate addr for prefix2 immedaitely. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + expectAutoGenAddrEvent(addr2, deprecatedAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr1 should be the primary endpoint now since addr2 is deprecated but + // addr1 is not. + expectPrimaryAddr(addr1) + // addr2 is deprecated but if explicitly requested, it should be used. + fullAddr2 := tcpip.FullAddress{Addr: addr2.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Another PI w/ 0 preferred lifetime should not result in a deprecation - // event. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) - if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) - } + // Another PI w/ 0 preferred lifetime should not result in a deprecation + // event. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) + if got := addrForNewConnectionWithAddr(t, s, fullAddr2); got != addr2.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr2, got, addr2.Address) + } - // Refresh lifetimes of addr generated from prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: + // Refresh lifetimes of addr generated from prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + }) } - expectPrimaryAddr(addr2) } // TestAutoGenAddrJobDeprecation tests that an address is properly deprecated @@ -2996,217 +3019,236 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) { const nicID = 1 const newMinVL = 2 newMinVLDuration := newMinVL * time.Second - saved := stack.MinPrefixInformationValidLifetimeForUpdate - defer func() { - stack.MinPrefixInformationValidLifetimeForUpdate = saved - }() - stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, + } - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + saved := stack.MinPrefixInformationValidLifetimeForUpdate + defer func() { + stack.MinPrefixInformationValidLifetimeForUpdate = saved + }() + stack.MinPrefixInformationValidLifetimeForUpdate = newMinVLDuration - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) + + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() + + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } } - default: - t.Fatal("expected addr auto gen event") - } - } - expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { - t.Helper() + expectAutoGenAddrEventAfter := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType, timeout time.Duration) { + t.Helper() - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(timeout): + t.Fatal("timed out waiting for addr auto gen event") + } } - case <-time.After(timeout): - t.Fatal("timed out waiting for addr auto gen event") - } - } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Receive PI for prefix2. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) - expectAutoGenAddrEvent(addr2, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Receive PI for prefix2. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, 100, 100)) + expectAutoGenAddrEvent(addr2, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Receive a PI for prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) - expectAutoGenAddrEvent(addr1, newAddr) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr1) + // Receive a PI for prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, 100, 90)) + expectAutoGenAddrEvent(addr1, newAddr) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr1) - // Refresh lifetime for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr1) + // Refresh lifetime for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since addr1 is deprecated but - // addr2 is not. - expectPrimaryAddr(addr2) - // addr1 is deprecated but if explicitly requested, it should be used. - fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since addr1 is deprecated but + // addr2 is not. + expectPrimaryAddr(addr2) + // addr1 is deprecated but if explicitly requested, it should be used. + fullAddr1 := tcpip.FullAddress{Addr: addr1.Address, NIC: nicID} + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make - // sure we do not get a deprecation event again. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Refresh valid lifetime for addr of prefix1, w/ 0 preferred lifetime to make + // sure we do not get a deprecation event again. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, 0)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Refresh lifetimes for addr of prefix1. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } - // addr1 is the primary endpoint again since it is non-deprecated now. - expectPrimaryAddr(addr1) + // Refresh lifetimes for addr of prefix1. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix1, true, true, newMinVL, newMinVL-1)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } + // addr1 is the primary endpoint again since it is non-deprecated now. + expectPrimaryAddr(addr1) - // Wait for addr of prefix1 to be deprecated. - expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - // addr2 should be the primary endpoint now since it is not deprecated. - expectPrimaryAddr(addr2) - if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { - t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) - } + // Wait for addr of prefix1 to be deprecated. + expectAutoGenAddrEventAfter(addr1, deprecatedAddr, newMinVLDuration-time.Second+defaultAsyncPositiveEventTimeout) + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + // addr2 should be the primary endpoint now since it is not deprecated. + expectPrimaryAddr(addr2) + if got := addrForNewConnectionWithAddr(t, s, fullAddr1); got != addr1.Address { + t.Errorf("got addrForNewConnectionWithAddr(_, _, %+v) = %s, want = %s", fullAddr1, got, addr1.Address) + } - // Wait for addr of prefix1 to be invalidated. - expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should have %s in the list of addresses", addr2) - } - expectPrimaryAddr(addr2) + // Wait for addr of prefix1 to be invalidated. + expectAutoGenAddrEventAfter(addr1, invalidatedAddr, time.Second+defaultAsyncPositiveEventTimeout) + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if !containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should have %s in the list of addresses", addr2) + } + expectPrimaryAddr(addr2) - // Refresh both lifetimes for addr of prefix2 to the same value. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - default: - } + // Refresh both lifetimes for addr of prefix2 to the same value. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix2, true, true, newMinVL, newMinVL)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + default: + } - // Wait for a deprecation then invalidation events, or just an invalidation - // event. We need to cover both cases but cannot deterministically hit both - // cases because the deprecation and invalidation handlers could be handled in - // either deprecation then invalidation, or invalidation then deprecation - // (which should be cancelled by the invalidation handler). - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { - // If we get a deprecation event first, we should get an invalidation - // event almost immediately after. + // Wait for a deprecation then invalidation events, or just an invalidation + // event. We need to cover both cases but cannot deterministically hit both + // cases because the deprecation and invalidation handlers could be handled in + // either deprecation then invalidation, or invalidation then deprecation + // (which should be cancelled by the invalidation handler). select { case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + if diff := checkAutoGenAddrEvent(e, addr2, deprecatedAddr); diff == "" { + // If we get a deprecation event first, we should get an invalidation + // event almost immediately after. + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + case <-time.After(defaultAsyncPositiveEventTimeout): + t.Fatal("timed out waiting for addr auto gen event") + } + } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { + // If we get an invalidation event first, we should not get a deprecation + // event after. + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto-generated event") + case <-time.After(defaultAsyncNegativeEventTimeout): + } + } else { + t.Fatalf("got unexpected auto-generated event") } - case <-time.After(defaultAsyncPositiveEventTimeout): + case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for addr auto gen event") } - } else if diff := checkAutoGenAddrEvent(e, addr2, invalidatedAddr); diff == "" { - // If we get an invalidation event first, we should not get a deprecation - // event after. - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto-generated event") - case <-time.After(defaultAsyncNegativeEventTimeout): + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { + t.Fatalf("should not have %s in the list of addresses", addr1) + } + if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { + t.Fatalf("should not have %s in the list of addresses", addr2) + } + // Should not have any primary endpoints. + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if want := (tcpip.AddressWithPrefix{}); got != want { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) + } + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) + } + defer ep.Close() + if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) } - } else { - t.Fatalf("got unexpected auto-generated event") - } - case <-time.After(newMinVLDuration + defaultAsyncPositiveEventTimeout): - t.Fatal("timed out waiting for addr auto gen event") - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr1) { - t.Fatalf("should not have %s in the list of addresses", addr1) - } - if containsV6Addr(s.NICInfo()[nicID].ProtocolAddresses, addr2) { - t.Fatalf("should not have %s in the list of addresses", addr2) - } - // Should not have any primary endpoints. - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if want := (tcpip.AddressWithPrefix{}); got != want { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want) - } - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - ep, err := s.NewEndpoint(header.UDPProtocolNumber, header.IPv6ProtocolNumber, &wq) - if err != nil { - t.Fatalf("s.NewEndpoint(%d, %d, _): %s", header.UDPProtocolNumber, header.IPv6ProtocolNumber, err) - } - defer ep.Close() - if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - t.Fatalf("SetSockOpt(tcpip.V6OnlyOption, true): %s", err) - } - if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { - t.Errorf("got ep.Connect(%+v) = %v, want = %s", dstAddr, err, tcpip.ErrNoRoute) + if err := ep.Connect(dstAddr); err != tcpip.ErrNoRoute { + t.Errorf("got ep.Connect(%+v) = %s, want = %s", dstAddr, err, tcpip.ErrNoRoute) + } + }) } } @@ -3524,110 +3566,128 @@ func TestAutoGenAddrRemoval(t *testing.T) { func TestAutoGenAddrAfterRemoval(t *testing.T) { const nicID = 1 - prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) - prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) - ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID) - - expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { - t.Helper() - - select { - case e := <-ndpDisp.autoGenAddrC: - if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { - t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) - } - default: - t.Fatal("expected addr auto gen event") - } + stacks := []struct { + name string + useNeighborCache bool + }{ + { + name: "linkAddrCache", + useNeighborCache: false, + }, + { + name: "neighborCache", + useNeighborCache: true, + }, } - expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { - t.Helper() + for _, stackTyp := range stacks { + t.Run(stackTyp.name, func(t *testing.T) { + prefix1, _, addr1 := prefixSubnetAddr(0, linkAddr1) + prefix2, _, addr2 := prefixSubnetAddr(1, linkAddr1) + ndpDisp, e, s := stackAndNdpDispatcherWithDefaultRoute(t, nicID, stackTyp.useNeighborCache) - if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { - t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) - } else if got != addr { - t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) - } - - if got := addrForNewConnection(t, s); got != addr.Address { - t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) - } - } + expectAutoGenAddrEvent := func(addr tcpip.AddressWithPrefix, eventType ndpAutoGenAddrEventType) { + t.Helper() - // Receive a PI to auto-generate addr1 with a large valid and preferred - // lifetime. - const largeLifetimeSeconds = 999 - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr1, newAddr) - expectPrimaryAddr(addr1) + select { + case e := <-ndpDisp.autoGenAddrC: + if diff := checkAutoGenAddrEvent(e, addr, eventType); diff != "" { + t.Errorf("auto-gen addr event mismatch (-want +got):\n%s", diff) + } + default: + t.Fatal("expected addr auto gen event") + } + } - // Add addr2 as a static address. - protoAddr2 := tcpip.ProtocolAddress{ - Protocol: header.IPv6ProtocolNumber, - AddressWithPrefix: addr2, - } - if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { - t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) - } - // addr2 should be more preferred now since it is at the front of the primary - // list. - expectPrimaryAddr(addr2) + expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) { + t.Helper() - // Get a route using addr2 to increment its reference count then remove it - // to leave it in the permanentExpired state. - r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) - if err != nil { - t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) - } - defer r.Release() - if err := s.RemoveAddress(nicID, addr2.Address); err != nil { - t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) - } - // addr1 should be preferred again since addr2 is in the expired state. - expectPrimaryAddr(addr1) + if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil { + t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err) + } else if got != addr { + t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr) + } - // Receive a PI to auto-generate addr2 as valid and preferred. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr2 should be more preferred now that it is closer to the front of the - // primary list and not deprecated. - expectPrimaryAddr(addr2) + if got := addrForNewConnection(t, s); got != addr.Address { + t.Errorf("got addrForNewConnection = %s, want = %s", got, addr.Address) + } + } - // Removing the address should result in an invalidation event immediately. - // It should still be in the permanentExpired state because r is still held. - // - // We remove addr2 here to make sure addr2 was marked as a SLAAC address - // (it was previously marked as a static address). - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) - } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - // addr1 should be more preferred since addr2 is in the expired state. - expectPrimaryAddr(addr1) + // Receive a PI to auto-generate addr1 with a large valid and preferred + // lifetime. + const largeLifetimeSeconds = 999 + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix1, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr1, newAddr) + expectPrimaryAddr(addr1) - // Receive a PI to auto-generate addr2 as valid and deprecated. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) - expectAutoGenAddrEvent(addr2, newAddr) - // addr1 should still be more preferred since addr2 is deprecated, even though - // it is closer to the front of the primary list. - expectPrimaryAddr(addr1) + // Add addr2 as a static address. + protoAddr2 := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: addr2, + } + if err := s.AddProtocolAddressWithOptions(nicID, protoAddr2, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("AddProtocolAddressWithOptions(%d, %+v, %d) = %s", nicID, protoAddr2, stack.FirstPrimaryEndpoint, err) + } + // addr2 should be more preferred now since it is at the front of the primary + // list. + expectPrimaryAddr(addr2) - // Receive a PI to refresh addr2's preferred lifetime. - e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) - select { - case <-ndpDisp.autoGenAddrC: - t.Fatal("unexpectedly got an auto gen addr event") - default: - } - // addr2 should be more preferred now that it is not deprecated. - expectPrimaryAddr(addr2) + // Get a route using addr2 to increment its reference count then remove it + // to leave it in the permanentExpired state. + r, err := s.FindRoute(nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, false) + if err != nil { + t.Fatalf("FindRoute(%d, %s, %s, %d, false): %s", nicID, addr2.Address, addr3, header.IPv6ProtocolNumber, err) + } + defer r.Release() + if err := s.RemoveAddress(nicID, addr2.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, addr2.Address, err) + } + // addr1 should be preferred again since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and preferred. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr2 should be more preferred now that it is closer to the front of the + // primary list and not deprecated. + expectPrimaryAddr(addr2) + + // Removing the address should result in an invalidation event immediately. + // It should still be in the permanentExpired state because r is still held. + // + // We remove addr2 here to make sure addr2 was marked as a SLAAC address + // (it was previously marked as a static address). + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + // addr1 should be more preferred since addr2 is in the expired state. + expectPrimaryAddr(addr1) + + // Receive a PI to auto-generate addr2 as valid and deprecated. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, 0)) + expectAutoGenAddrEvent(addr2, newAddr) + // addr1 should still be more preferred since addr2 is deprecated, even though + // it is closer to the front of the primary list. + expectPrimaryAddr(addr1) + + // Receive a PI to refresh addr2's preferred lifetime. + e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr3, 0, prefix2, true, true, largeLifetimeSeconds, largeLifetimeSeconds)) + select { + case <-ndpDisp.autoGenAddrC: + t.Fatal("unexpectedly got an auto gen addr event") + default: + } + // addr2 should be more preferred now that it is not deprecated. + expectPrimaryAddr(addr2) - if err := s.RemoveAddress(1, addr2.Address); err != nil { - t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + if err := s.RemoveAddress(1, addr2.Address); err != nil { + t.Fatalf("RemoveAddress(_, %s) = %s", addr2.Address, err) + } + expectAutoGenAddrEvent(addr2, invalidatedAddr) + expectPrimaryAddr(addr1) + }) } - expectAutoGenAddrEvent(addr2, invalidatedAddr) - expectPrimaryAddr(addr1) } // TestAutoGenAddrStaticConflict tests that if SLAAC generates an address that diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index e74d2562a..be274773c 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -21,6 +21,7 @@ import ( "sort" "sync/atomic" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -135,24 +136,33 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC } nic.mu.ndp.initializeTempAddrState() - // Register supported packet endpoint protocols. - for _, netProto := range header.Ethertypes { - nic.mu.packetEPs[netProto] = []PacketEndpoint{} - } - for _, netProto := range stack.networkProtocols { - netNum := netProto.Number() - nic.mu.packetEPs[netNum] = nil - nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nic, ep, stack) - } - // Check for Neighbor Unreachability Detection support. - if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 { + var nud NUDHandler + if ep.Capabilities()&CapabilityResolutionRequired != 0 && len(stack.linkAddrResolvers) != 0 && stack.useNeighborCache { rng := rand.New(rand.NewSource(stack.clock.NowNanoseconds())) nic.neigh = &neighborCache{ nic: nic, state: NewNUDState(stack.nudConfigs, rng), cache: make(map[tcpip.Address]*neighborEntry, neighborCacheSize), } + + // An interface value that holds a nil pointer but non-nil type is not the + // same as the nil interface. Because of this, nud must only be assignd if + // nic.neigh is non-nil since a nil reference to a neighborCache is not + // valid. + // + // See https://golang.org/doc/faq#nil_error for more information. + nud = nic.neigh + } + + // Register supported packet and network endpoint protocols. + for _, netProto := range header.Ethertypes { + nic.mu.packetEPs[netProto] = []PacketEndpoint{} + } + for _, netProto := range stack.networkProtocols { + netNum := netProto.Number() + nic.mu.packetEPs[netNum] = nil + nic.networkEndpoints[netNum] = netProto.NewEndpoint(id, stack, nud, nic, ep, stack) } nic.linkEP.Attach(nic) @@ -431,7 +441,7 @@ func (n *NIC) setSpoofing(enable bool) { // If an IPv6 primary endpoint is requested, Source Address Selection (as // defined by RFC 6724 section 5) will be performed. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber, remoteAddr tcpip.Address) *referencedNetworkEndpoint { - if protocol == header.IPv6ProtocolNumber && remoteAddr != "" { + if protocol == header.IPv6ProtocolNumber && len(remoteAddr) != 0 { return n.primaryIPv6Endpoint(remoteAddr) } @@ -655,22 +665,15 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t } } - // Check if address is a broadcast address for the endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { + if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { n.mu.RUnlock() return ref } } - - // A usable reference was not found, create a temporary one if requested by - // the caller or if the address is found in the NIC's subnets. - createTempEP := spoofingOrPromiscuous n.mu.RUnlock() - if !createTempEP { + if !spoofingOrPromiscuous { return nil } @@ -683,20 +686,21 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } -// getRefForBroadcastLocked returns an endpoint where address is the IPv4 -// broadcast address for the endpoint's network. +// getRefForBroadcastOrLoopbackRLocked returns an endpoint whose address is the +// broadcast address for the endpoint's network or an address in the endpoint's +// subnet if the NIC is a loopback interface. This matches linux behaviour. // -// n.mu MUST be read locked. -func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { +// n.mu MUST be read or write locked. +func (n *NIC) getIPv4RefForBroadcastOrLoopbackRLocked(address tcpip.Address) *referencedNetworkEndpoint { for _, ref := range n.mu.endpoints { - // Only IPv4 has a notion of broadcast addresses. + // Only IPv4 has a notion of broadcast addresses or considers the loopback + // interface bound to an address's whole subnet (on linux). if ref.protocol != header.IPv4ProtocolNumber { continue } - addr := ref.addrWithPrefix() - subnet := addr.Subnet() - if subnet.IsBroadcast(address) && ref.tryIncRef() { + subnet := ref.addrWithPrefix().Subnet() + if (subnet.IsBroadcast(address) || (n.isLoopback() && subnet.Contains(address))) && ref.isValidForOutgoingRLocked() && ref.tryIncRef() { return ref } } @@ -724,11 +728,8 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add n.removeEndpointLocked(ref) } - // Check if address is a broadcast address for an endpoint's network. - // - // Only IPv4 has a notion of broadcast addresses. if protocol == header.IPv4ProtocolNumber { - if ref := n.getRefForBroadcastRLocked(address); ref != nil { + if ref := n.getIPv4RefForBroadcastOrLoopbackRLocked(address); ref != nil { return ref } } @@ -833,10 +834,11 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar deprecated: deprecated, } - // Set up cache if link address resolution exists for this protocol. + // Set up resolver if link address resolution exists for this protocol. if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { - if _, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { + if linkRes, ok := n.stack.linkAddrResolvers[protocolAddress.Protocol]; ok { ref.linkCache = n.stack + ref.linkRes = linkRes } } @@ -1071,6 +1073,51 @@ func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { return n.removePermanentAddressLocked(addr) } +func (n *NIC) neighbors() ([]NeighborEntry, *tcpip.Error) { + if n.neigh == nil { + return nil, tcpip.ErrNotSupported + } + + return n.neigh.entries(), nil +} + +func (n *NIC) removeWaker(addr tcpip.Address, w *sleep.Waker) { + if n.neigh == nil { + return + } + + n.neigh.removeWaker(addr, w) +} + +func (n *NIC) addStaticNeighbor(addr tcpip.Address, linkAddress tcpip.LinkAddress) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported + } + + n.neigh.addStaticEntry(addr, linkAddress) + return nil +} + +func (n *NIC) removeNeighbor(addr tcpip.Address) *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported + } + + if !n.neigh.removeEntry(addr) { + return tcpip.ErrBadAddress + } + return nil +} + +func (n *NIC) clearNeighbors() *tcpip.Error { + if n.neigh == nil { + return tcpip.ErrNotSupported + } + + n.neigh.clear() + return nil +} + // joinGroup adds a new endpoint for the given multicast address, if none // exists yet. Otherwise it just increments its count. func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { @@ -1235,14 +1282,14 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp return } - // TODO(gvisor.dev/issue/170): Not supporting iptables for IPv6 yet. // Loopback traffic skips the prerouting chain. - if protocol == header.IPv4ProtocolNumber && !n.isLoopback() { + if !n.isLoopback() { // iptables filtering. ipt := n.stack.IPTables() address := n.primaryAddress(protocol) if ok := ipt.Check(Prerouting, pkt, nil, nil, address.Address, ""); !ok { // iptables is telling us to drop the packet. + n.stack.stats.IP.IPTablesPreroutingDropped.Increment() return } } @@ -1652,6 +1699,10 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache + // linkRes is set if link address resolution is enabled for this protocol. + // Set to nil otherwise. + linkRes LinkAddressResolver + // refs is counting references held for this endpoint. When refs hits zero it // triggers the automatic removal of the endpoint from the NIC. refs int32 diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index d312a79eb..dd6474297 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -192,7 +192,7 @@ func (*testIPv6Protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) } // NewEndpoint implements NetworkProtocol.NewEndpoint. -func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { +func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ NUDHandler, _ TransportDispatcher, linkEP LinkEndpoint, _ *Stack) NetworkEndpoint { return &testIPv6Endpoint{ nicID: nicID, linkEP: linkEP, @@ -201,12 +201,12 @@ func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _ } // SetOption implements NetworkProtocol.SetOption. -func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error { +func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error { return nil } // Option implements NetworkProtocol.Option. -func (*testIPv6Protocol) Option(interface{}) *tcpip.Error { +func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud_test.go b/pkg/tcpip/stack/nud_test.go index 2494ee610..2b97e5972 100644 --- a/pkg/tcpip/stack/nud_test.go +++ b/pkg/tcpip/stack/nud_test.go @@ -61,6 +61,7 @@ func TestSetNUDConfigurationFailsForBadNICID(t *testing.T) { // stack will only allocate neighbor caches if a protocol providing link // address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, + UseNeighborCache: true, }) // No NIC with ID 1 yet. @@ -84,7 +85,8 @@ func TestNUDConfigurationFailsForNotSupported(t *testing.T) { e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -108,7 +110,8 @@ func TestSetNUDConfigurationFailsForNotSupported(t *testing.T) { e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ - NUDConfigs: stack.DefaultNUDConfigurations(), + NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -136,6 +139,7 @@ func TestDefaultNUDConfigurations(t *testing.T) { // address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: stack.DefaultNUDConfigurations(), + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -190,6 +194,7 @@ func TestNUDConfigurationsBaseReachableTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -246,6 +251,7 @@ func TestNUDConfigurationsMinRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -325,6 +331,7 @@ func TestNUDConfigurationsMaxRandomFactor(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -386,6 +393,7 @@ func TestNUDConfigurationsRetransmitTimer(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -437,6 +445,7 @@ func TestNUDConfigurationsDelayFirstProbeTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -488,6 +497,7 @@ func TestNUDConfigurationsMaxMulticastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -539,6 +549,7 @@ func TestNUDConfigurationsMaxUnicastProbes(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) @@ -590,6 +601,7 @@ func TestNUDConfigurationsUnreachableTime(t *testing.T) { // providing link address resolution is specified (e.g. ARP or IPv6). NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, NUDConfigs: c, + UseNeighborCache: true, }) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 17b8beebb..1932aaeb7 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -80,7 +80,7 @@ type PacketBuffer struct { // data are held in the same underlying buffer storage. header buffer.Prependable - // NetworkProtocol is only valid when NetworkHeader is set. + // NetworkProtocolNumber is only valid when NetworkHeader is set. // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol // numbers in registration APIs that take a PacketBuffer. NetworkProtocolNumber tcpip.NetworkProtocolNumber diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index aca2f77f8..4fa86a3ac 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -159,12 +159,12 @@ type TransportProtocol interface { // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -298,17 +298,17 @@ type NetworkProtocol interface { ParseAddresses(v buffer.View) (src, dst tcpip.Address) // NewEndpoint creates a new endpoint of this protocol. - NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint + NewEndpoint(nicID tcpip.NICID, linkAddrCache LinkAddressCache, nud NUDHandler, dispatcher TransportDispatcher, sender LinkEndpoint, st *Stack) NetworkEndpoint // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the // provided option value is invalid. - SetOption(option interface{}) *tcpip.Error + SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error // Option allows retrieving protocol specific option values. // Option returns an error if the option is not supported or the // provided option value is invalid. - Option(option interface{}) *tcpip.Error + Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error // Close requests that any worker goroutines owned by the protocol // stop. @@ -488,7 +488,7 @@ type LinkAddressResolver interface { ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) // LinkAddressProtocol returns the network protocol of the - // addresses this this resolver can resolve. + // addresses this resolver can resolve. LinkAddressProtocol() tcpip.NetworkProtocolNumber } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index e267bebb0..2cbbf0de8 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -48,10 +48,6 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - - // directedBroadcast indicates whether this route is sending a directed - // broadcast packet. - directedBroadcast bool } // makeRoute initializes a new route. It takes ownership of the provided @@ -141,6 +137,16 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { } nextAddr = r.RemoteAddress } + + if r.ref.nic.neigh != nil { + entry, ch, err := r.ref.nic.neigh.entry(nextAddr, r.LocalAddress, r.ref.linkRes, waker) + if err != nil { + return ch, err + } + r.RemoteLinkAddress = entry.LinkAddr + return nil, nil + } + linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) if err != nil { return ch, err @@ -155,6 +161,12 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { if nextAddr == "" { nextAddr = r.RemoteAddress } + + if r.ref.nic.neigh != nil { + r.ref.nic.neigh.removeWaker(nextAddr, waker) + return + } + r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) } @@ -163,6 +175,9 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // // The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { + if r.ref.nic.neigh != nil { + return r.ref.isValidForOutgoing() && r.ref.linkRes != nil && r.RemoteLinkAddress == "" + } return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } @@ -284,24 +299,27 @@ func (r *Route) Stack() *Stack { return r.ref.stack() } +func (r *Route) isV4Broadcast(addr tcpip.Address) bool { + if addr == header.IPv4Broadcast { + return true + } + + subnet := r.ref.addrWithPrefix().Subnet() + return subnet.IsBroadcast(addr) +} + // IsOutboundBroadcast returns true if the route is for an outbound broadcast // packet. func (r *Route) IsOutboundBroadcast() bool { // Only IPv4 has a notion of broadcast. - return r.directedBroadcast || r.RemoteAddress == header.IPv4Broadcast + return r.isV4Broadcast(r.RemoteAddress) } // IsInboundBroadcast returns true if the route is for an inbound broadcast // packet. func (r *Route) IsInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. - if r.LocalAddress == header.IPv4Broadcast { - return true - } - - addr := r.ref.addrWithPrefix() - subnet := addr.Subnet() - return subnet.IsBroadcast(r.LocalAddress) + return r.isV4Broadcast(r.LocalAddress) } // ReverseRoute returns new route with given source and destination address. diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 814b3e94a..68cf77de2 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,7 +22,6 @@ package stack import ( "bytes" "encoding/binary" - "math" mathrand "math/rand" "sync/atomic" "time" @@ -51,41 +50,6 @@ const ( DefaultTOS = 0 ) -const ( - // fakeNetNumber is used as a protocol number in tests. - // - // This constant should match fakeNetNumber in stack_test.go. - fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 -) - -type forwardingFlag uint32 - -// Packet forwarding flags. Forwarding settings for different network protocols -// are stored as bit flags in an uint32 number. -const ( - forwardingIPv4 forwardingFlag = 1 << iota - forwardingIPv6 - - // forwardingFake is used to test package forwarding with a fake protocol. - forwardingFake -) - -func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag { - var flag forwardingFlag - switch protocol { - case header.IPv4ProtocolNumber: - flag = forwardingIPv4 - case header.IPv6ProtocolNumber: - flag = forwardingIPv6 - case fakeNetNumber: - // This network protocol number is used to test packet forwarding. - flag = forwardingFake - default: - // We only support forwarding for IPv4 and IPv6. - } - return flag -} - type transportProtocolState struct { proto TransportProtocol defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool @@ -284,7 +248,7 @@ type RcvBufAutoTuneParams struct { // was started. MeasureTime time.Time - // CopiedBytes is the number of bytes copied to userspace since + // CopiedBytes is the number of bytes copied to user space since // this measure began. CopiedBytes int @@ -441,6 +405,13 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // forwarding contains the whether packet forwarding is enabled or not for + // different network protocols. + forwarding struct { + sync.RWMutex + protocols map[tcpip.NetworkProtocolNumber]bool + } + // rawFactory creates raw endpoints. If nil, raw endpoints are // disabled. It is set during Stack creation and is immutable. rawFactory RawFactory @@ -454,14 +425,9 @@ type Stack struct { mu sync.RWMutex nics map[tcpip.NICID]*NIC - // forwarding contains the enable bits for packet forwarding for different - // network protocols. - forwarding struct { - sync.RWMutex - flag forwardingFlag - } - - cleanupEndpoints map[TransportEndpoint]struct{} + // cleanupEndpointsMu protects cleanupEndpoints. + cleanupEndpointsMu sync.Mutex + cleanupEndpoints map[TransportEndpoint]struct{} // route is the route table passed in by the user via SetRouteTable(), // it is used by FindRoute() to build a route for a specific @@ -472,7 +438,7 @@ type Stack struct { // If not nil, then any new endpoints will have this probe function // invoked everytime they receive a TCP segment. - tcpProbeFunc TCPProbeFunc + tcpProbeFunc atomic.Value // TCPProbeFunc // clock is used to generate user-visible times. clock tcpip.Clock @@ -504,6 +470,10 @@ type Stack struct { // nudConfigs is the default NUD configurations used by interfaces. nudConfigs NUDConfigurations + // useNeighborCache indicates whether ARP and NDP packets should be handled + // by the NIC's neighborCache instead of linkAddrCache. + useNeighborCache bool + // autoGenIPv6LinkLocal determines whether or not the stack will attempt // to auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. See the AutoGenIPv6LinkLocal field of Options for more details. @@ -584,6 +554,13 @@ type Options struct { // NUDConfigs is the default NUD configurations used by interfaces. NUDConfigs NUDConfigurations + // UseNeighborCache indicates whether ARP and NDP packets should be handled + // by the Neighbor Unreachability Detection (NUD) state machine. This flag + // also enables the APIs for inspecting and modifying the neighbor table via + // NUDDispatcher and the following Stack methods: Neighbors, RemoveNeighbor, + // and ClearNeighbors. + UseNeighborCache bool + // AutoGenIPv6LinkLocal determines whether or not the stack will attempt to // auto-generate an IPv6 link-local address for newly enabled non-loopback // NICs. @@ -758,6 +735,7 @@ func New(opts Options) *Stack { seed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, nudConfigs: opts.NUDConfigs, + useNeighborCache: opts.UseNeighborCache, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, @@ -777,6 +755,7 @@ func New(opts Options) *Stack { Max: DefaultMaxBufferSize, }, } + s.forwarding.protocols = make(map[tcpip.NetworkProtocolNumber]bool) // Add specified network protocols. for _, netProto := range opts.NetworkProtocols { @@ -816,7 +795,7 @@ func (s *Stack) UniqueID() uint64 { // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { return tcpip.ErrUnknownProtocol @@ -833,7 +812,7 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op // if err != nil { // ... // } -func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error { netProto, ok := s.networkProtocols[network] if !ok { return tcpip.ErrUnknownProtocol @@ -845,7 +824,7 @@ func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, optio // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value // is incorrect. -func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -860,7 +839,7 @@ func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumb // if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { // ... // } -func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.GettableTransportProtocolOption) *tcpip.Error { transProtoState, ok := s.transportProtocols[transport] if !ok { return tcpip.ErrUnknownProtocol @@ -904,23 +883,14 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) return } - flag := getForwardingFlag(protocol) - // If the forwarding value for this protocol hasn't changed then do // nothing. - if s.forwarding.flag&getForwardingFlag(protocol) != 0 == enable { + if forwarding := s.forwarding.protocols[protocol]; forwarding == enable { return } - var newValue forwardingFlag - if enable { - newValue = s.forwarding.flag | flag - } else { - newValue = s.forwarding.flag & ^flag - } - s.forwarding.flag = newValue + s.forwarding.protocols[protocol] = enable - // Enable or disable NDP for IPv6. if protocol == header.IPv6ProtocolNumber { if enable { for _, nic := range s.nics { @@ -938,7 +908,7 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { s.forwarding.RLock() defer s.forwarding.RUnlock() - return s.forwarding.flag&getForwardingFlag(protocol) != 0 + return s.forwarding.protocols[protocol] } // SetRouteTable assigns the route table to be used by this stack. It @@ -1257,8 +1227,8 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[id] - if nic == nil { + nic, ok := s.nics[id] + if !ok { return tcpip.ErrUnknownNICID } @@ -1344,13 +1314,11 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n } r := makeRoute(netProto, ref.address(), remoteAddr, nic.linkEP.LinkAddress(), ref, s.handleLocal && !nic.isLoopback(), multicastLoop && !nic.isLoopback()) - r.directedBroadcast = route.Destination.IsBroadcast(remoteAddr) - if len(route.Gateway) > 0 { if needRoute { r.NextHop = route.Gateway } - } else if r.directedBroadcast { + } else if subnet := ref.addrWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { r.RemoteLinkAddress = header.EthernetBroadcastAddress } @@ -1383,8 +1351,8 @@ func (s *Stack) CheckLocalAddress(nicID tcpip.NICID, protocol tcpip.NetworkProto // If a NIC is specified, we try to find the address there only. if nicID != 0 { - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return 0 } @@ -1415,8 +1383,8 @@ func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return tcpip.ErrUnknownNICID } @@ -1431,8 +1399,8 @@ func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() - nic := s.nics[nicID] - if nic == nil { + nic, ok := s.nics[nicID] + if !ok { return tcpip.ErrUnknownNICID } @@ -1464,8 +1432,33 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) } -// RemoveWaker implements LinkAddressCache.RemoveWaker. +// Neighbors returns all IP to MAC address associations. +func (s *Stack) Neighbors(nicID tcpip.NICID) ([]NeighborEntry, *tcpip.Error) { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return nil, tcpip.ErrUnknownNICID + } + + return nic.neighbors() +} + +// RemoveWaker removes a waker that has been added when link resolution for +// addr was requested. func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { + if s.useNeighborCache { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if ok { + nic.removeWaker(addr, waker) + } + return + } + s.mu.RLock() defer s.mu.RUnlock() @@ -1475,6 +1468,47 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep. } } +// AddStaticNeighbor statically associates an IP address to a MAC address. +func (s *Stack) AddStaticNeighbor(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.addStaticNeighbor(addr, linkAddr) +} + +// RemoveNeighbor removes an IP to MAC address association previously created +// either automically or by AddStaticNeighbor. Returns ErrBadAddress if there +// is no association with the provided address. +func (s *Stack) RemoveNeighbor(nicID tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.removeNeighbor(addr) +} + +// ClearNeighbors removes all IP to MAC address associations. +func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error { + s.mu.RLock() + nic, ok := s.nics[nicID] + s.mu.RUnlock() + + if !ok { + return tcpip.ErrUnknownNICID + } + + return nic.clearNeighbors() +} + // RegisterTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but @@ -1498,10 +1532,9 @@ func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip // StartTransportEndpointCleanup removes the endpoint with the given id from // the stack transport dispatcher. It also transitions it to the cleanup stage. func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) { - s.mu.Lock() - defer s.mu.Unlock() - + s.cleanupEndpointsMu.Lock() s.cleanupEndpoints[ep] = struct{}{} + s.cleanupEndpointsMu.Unlock() s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice) } @@ -1509,9 +1542,9 @@ func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcp // CompleteTransportEndpointCleanup removes the endpoint from the cleanup // stage. func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() delete(s.cleanupEndpoints, ep) - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // FindTransportEndpoint finds an endpoint that most closely matches the provided @@ -1554,23 +1587,23 @@ func (s *Stack) RegisteredEndpoints() []TransportEndpoint { // CleanupEndpoints returns endpoints currently in the cleanup state. func (s *Stack) CleanupEndpoints() []TransportEndpoint { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() es := make([]TransportEndpoint, 0, len(s.cleanupEndpoints)) for e := range s.cleanupEndpoints { es = append(es, e) } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() return es } // RestoreCleanupEndpoints adds endpoints to cleanup tracking. This is useful // for restoring a stack after a save. func (s *Stack) RestoreCleanupEndpoints(es []TransportEndpoint) { - s.mu.Lock() + s.cleanupEndpointsMu.Lock() for _, e := range es { s.cleanupEndpoints[e] = struct{}{} } - s.mu.Unlock() + s.cleanupEndpointsMu.Unlock() } // Close closes all currently registered transport endpoints. @@ -1765,18 +1798,17 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra // guarantee provided on which probe will be invoked. Ideally this should only // be called once per stack. func (s *Stack) AddTCPProbe(probe TCPProbeFunc) { - s.mu.Lock() - s.tcpProbeFunc = probe - s.mu.Unlock() + s.tcpProbeFunc.Store(probe) } // GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil // otherwise. func (s *Stack) GetTCPProbe() TCPProbeFunc { - s.mu.Lock() - p := s.tcpProbeFunc - s.mu.Unlock() - return p + p := s.tcpProbeFunc.Load() + if p == nil { + return nil + } + return p.(TCPProbeFunc) } // RemoveTCPProbe removes an installed TCP probe. @@ -1785,9 +1817,8 @@ func (s *Stack) GetTCPProbe() TCPProbeFunc { // have a probe attached. Endpoints already created will continue to invoke // TCP probe. func (s *Stack) RemoveTCPProbe() { - s.mu.Lock() - s.tcpProbeFunc = nil - s.mu.Unlock() + // This must be TCPProbeFunc(nil) because atomic.Value.Store(nil) panics. + s.tcpProbeFunc.Store(TCPProbeFunc(nil)) } // JoinGroup joins the given multicast group on the given NIC. @@ -2009,7 +2040,7 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres return nil, tcpip.ErrBadAddress } -// FindNICNameFromID returns the name of the nic for the given NICID. +// FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() defer s.mu.RUnlock() diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f168be402..7669ba672 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -42,9 +42,6 @@ import ( ) const ( - // fakeNetNumber is used as a protocol number in tests. - // - // This constant should match fakeNetNumber in stack.go. fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 fakeNetHeaderLen = 12 fakeDefaultPrefixLen = 8 @@ -161,23 +158,13 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack func (*fakeNetworkEndpoint) Close() {} -type fakeNetGoodOption bool - -type fakeNetBadOption bool - -type fakeNetInvalidValueOption int - -type fakeNetOptions struct { - good bool -} - // fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the // number of packets sent and received via endpoints of this protocol. The index // where packets are added is given by the packet's destination address MOD 10. type fakeNetworkProtocol struct { packetCount [10]int sendPacketCount [10]int - opts fakeNetOptions + defaultTTL uint8 } func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { @@ -200,7 +187,7 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { +func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddressCache, _ stack.NUDHandler, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) stack.NetworkEndpoint { return &fakeNetworkEndpoint{ nicID: nicID, proto: f, @@ -209,22 +196,20 @@ func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, linkAddrCache stack } } -func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeNetGoodOption: - f.opts.good = bool(v) + case *tcpip.DefaultTTLOption: + f.defaultTTL = uint8(*v) return nil - case fakeNetInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeNetGoodOption: - *v = fakeNetGoodOption(f.opts.good) + case *tcpip.DefaultTTLOption: + *v = tcpip.DefaultTTLOption(f.defaultTTL) return nil default: return tcpip.ErrUnknownProtocolOption @@ -1643,46 +1628,24 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -func TestNetworkOptions(t *testing.T) { +func TestNetworkOption(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, TransportProtocols: []stack.TransportProtocol{}, }) - // Try an unsupported network protocol. - if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) + opt := tcpip.DefaultTTLOption(5) + if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil { + t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err) } - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.NetworkProtocol) - }{ - {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { - t.Helper() - fakeNet := p.(*fakeNetworkProtocol) - if fakeNet.opts.good != true { - t.Fatalf("fakeNet.opts.good = false, want = true") - } - var v fakeNetGoodOption - if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) - } - }}, - {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, + var optGot tcpip.DefaultTTLOption + if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil { + t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err) } - for _, tc := range testCases { - if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) - } + + if opt != optGot { + t.Errorf("got optGot = %d, want = %d", optGot, opt) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index b902c6ca9..0774b5382 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -165,7 +165,7 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isMulticastOrBroadcast(id.LocalAddress) { + if isInboundMulticastOrBroadcast(r) { mpep.handlePacketAll(r, id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return @@ -526,7 +526,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() @@ -546,7 +546,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a TCP packet with a non-unicast source or destination // address, then do nothing further and instruct the caller to do the same. - if protocol == header.TCPProtocolNumber && (!isUnicast(r.LocalAddress) || !isUnicast(r.RemoteAddress)) { + if protocol == header.TCPProtocolNumber && (!isInboundUnicast(r) || !isOutboundUnicast(r)) { // TCP can only be used to communicate between a single source and a // single destination; the addresses must be unicast. r.Stats().TCP.InvalidSegmentsReceived.Increment() @@ -677,10 +677,14 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isMulticastOrBroadcast(addr tcpip.Address) bool { - return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) +func isInboundMulticastOrBroadcast(r *Route) bool { + return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) } -func isUnicast(addr tcpip.Address) bool { - return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr) +func isInboundUnicast(r *Route) bool { + return r.LocalAddress != header.IPv4Any && r.LocalAddress != header.IPv6Any && !isInboundMulticastOrBroadcast(r) +} + +func isOutboundUnicast(r *Route) bool { + return r.RemoteAddress != header.IPv4Any && r.RemoteAddress != header.IPv6Any && !r.IsOutboundBroadcast() && !header.IsV4MulticastAddress(r.RemoteAddress) && !header.IsV6MulticastAddress(r.RemoteAddress) } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 1339edc2d..4d6d62eec 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -312,8 +312,8 @@ func TestBindToDeviceDistribution(t *testing.T) { t.Fatalf("SetSockOptBool(ReusePortOption, %t) on endpoint %d failed: %s", endpoint.reuse, i, err) } bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) - if err := ep.SetSockOpt(bindToDeviceOption); err != nil { - t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %s", bindToDeviceOption, i, err) + if err := ep.SetSockOpt(&bindToDeviceOption); err != nil { + t.Fatalf("SetSockOpt(&%T(%d)) on endpoint %d failed: %s", bindToDeviceOption, bindToDeviceOption, i, err) } var dstAddr tcpip.Address diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index fa4b14ba6..64e44bc99 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -53,11 +53,11 @@ func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { return &f.TransportEndpointInfo } -func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { +func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} @@ -100,12 +100,12 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } // SetSockOpt sets a socket option. Currently not supported. -func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { +func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } @@ -130,11 +130,7 @@ func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.E } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - } +func (*fakeTransportEndpoint) GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState } @@ -169,7 +165,7 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -184,7 +180,7 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { return nil } -func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { if len(f.acceptQueue) == 0 { return nil, nil, nil } @@ -239,19 +235,19 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } -func (f *fakeTransportEndpoint) State() uint32 { +func (*fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} +func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { - return stack.IPTables{}, nil -} +func (*fakeTransportEndpoint) Resume(*stack.Stack) {} -func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} +func (*fakeTransportEndpoint) Wait() {} -func (f *fakeTransportEndpoint) Wait() {} +func (*fakeTransportEndpoint) LastError() *tcpip.Error { + return nil +} type fakeTransportGoodOption bool @@ -295,22 +291,20 @@ func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack return true } -func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case fakeTransportGoodOption: - f.opts.good = bool(v) + case *tcpip.TCPModerateReceiveBufferOption: + f.opts.good = bool(*v) return nil - case fakeTransportInvalidValueOption: - return tcpip.ErrInvalidOptionValue default: return tcpip.ErrUnknownProtocolOption } } -func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { +func (f *fakeTransportProtocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *fakeTransportGoodOption: - *v = fakeTransportGoodOption(f.opts.good) + case *tcpip.TCPModerateReceiveBufferOption: + *v = tcpip.TCPModerateReceiveBufferOption(f.opts.good) return nil default: return tcpip.ErrUnknownProtocolOption @@ -537,41 +531,16 @@ func TestTransportOptions(t *testing.T) { TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, }) - // Try an unsupported transport protocol. - if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { - t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) - } - - testCases := []struct { - option interface{} - wantErr *tcpip.Error - verifier func(t *testing.T, p stack.TransportProtocol) - }{ - {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { - t.Helper() - fakeTrans := p.(*fakeTransportProtocol) - if fakeTrans.opts.good != true { - t.Fatalf("fakeTrans.opts.good = false, want = true") - } - var v fakeTransportGoodOption - if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) - } - if v != true { - t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) - } - - }}, - {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, - {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, - } - for _, tc := range testCases { - if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { - t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) - } - if tc.verifier != nil { - tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) - } + v := tcpip.TCPModerateReceiveBufferOption(true) + if err := s.SetTransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Errorf("s.SetTransportProtocolOption(fakeTrans, &%T(%t)): %s", v, v, err) + } + v = false + if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &%T): %s", v, err) + } + if !v { + t.Fatalf("got tcpip.TCPModerateReceiveBufferOption = false, want = true") } } @@ -635,7 +604,7 @@ func TestTransportForwarding(t *testing.T) { Data: req.ToVectorisedView(), })) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err != nil || aep == nil { t.Fatalf("Accept failed: %v, %v", aep, err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 07c85ce59..464608dee 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -561,7 +561,10 @@ type Endpoint interface { // block if no new connections are available. // // The returned Queue is the wait queue for the newly created endpoint. - Accept() (Endpoint, *waiter.Queue, *Error) + // + // If peerAddr is not nil then it is populated with the peer address of the + // returned endpoint. + Accept(peerAddr *FullAddress) (Endpoint, *waiter.Queue, *Error) // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. @@ -578,8 +581,8 @@ type Endpoint interface { // if waiter.EventIn is set, the endpoint is immediately readable. Readiness(mask waiter.EventMask) waiter.EventMask - // SetSockOpt sets a socket option. opt should be one of the *Option types. - SetSockOpt(opt interface{}) *Error + // SetSockOpt sets a socket option. + SetSockOpt(opt SettableSocketOption) *Error // SetSockOptBool sets a socket option, for simple cases where a value // has the bool type. @@ -589,9 +592,8 @@ type Endpoint interface { // has the int type. SetSockOptInt(opt SockOptInt, v int) *Error - // GetSockOpt gets a socket option. opt should be a pointer to one of the - // *Option types. - GetSockOpt(opt interface{}) *Error + // GetSockOpt gets a socket option. + GetSockOpt(opt GettableSocketOption) *Error // GetSockOptBool gets a socket option for simple cases where a return // value has the bool type. @@ -620,6 +622,9 @@ type Endpoint interface { // SetOwner sets the task owner to the endpoint owner. SetOwner(owner PacketOwner) + + // LastError clears and returns the last error reported by the endpoint. + LastError() *Error } // LinkPacketInfo holds Link layer information for a received packet. @@ -839,14 +844,134 @@ const ( PMTUDiscoveryProbe ) -// ErrorOption is used in GetSockOpt to specify that the last error reported by -// the endpoint should be cleared and returned. -type ErrorOption struct{} +// GettableNetworkProtocolOption is a marker interface for network protocol +// options that may be queried. +type GettableNetworkProtocolOption interface { + isGettableNetworkProtocolOption() +} + +// SettableNetworkProtocolOption is a marker interface for network protocol +// options that may be set. +type SettableNetworkProtocolOption interface { + isSettableNetworkProtocolOption() +} + +// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify +// a default TTL. +type DefaultTTLOption uint8 + +func (*DefaultTTLOption) isGettableNetworkProtocolOption() {} + +func (*DefaultTTLOption) isSettableNetworkProtocolOption() {} + +// GettableTransportProtocolOption is a marker interface for transport protocol +// options that may be queried. +type GettableTransportProtocolOption interface { + isGettableTransportProtocolOption() +} + +// SettableTransportProtocolOption is a marker interface for transport protocol +// options that may be set. +type SettableTransportProtocolOption interface { + isSettableTransportProtocolOption() +} + +// TCPSACKEnabled the SACK option for TCP. +// +// See: https://tools.ietf.org/html/rfc2018. +type TCPSACKEnabled bool + +func (*TCPSACKEnabled) isGettableTransportProtocolOption() {} + +func (*TCPSACKEnabled) isSettableTransportProtocolOption() {} + +// TCPRecovery is the loss deteoction algorithm used by TCP. +type TCPRecovery int32 + +func (*TCPRecovery) isGettableTransportProtocolOption() {} + +func (*TCPRecovery) isSettableTransportProtocolOption() {} + +const ( + // TCPRACKLossDetection indicates RACK is used for loss detection and + // recovery. + TCPRACKLossDetection TCPRecovery = 1 << iota + + // TCPRACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + TCPRACKStaticReoWnd + + // TCPRACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + TCPRACKNoDupTh +) + +// TCPDelayEnabled enables/disables Nagle's algorithm in TCP. +type TCPDelayEnabled bool + +func (*TCPDelayEnabled) isGettableTransportProtocolOption() {} + +func (*TCPDelayEnabled) isSettableTransportProtocolOption() {} + +// TCPSendBufferSizeRangeOption is the send buffer size range for TCP. +type TCPSendBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPSendBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPSendBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPReceiveBufferSizeRangeOption is the receive buffer size range for TCP. +type TCPReceiveBufferSizeRangeOption struct { + Min int + Default int + Max int +} + +func (*TCPReceiveBufferSizeRangeOption) isGettableTransportProtocolOption() {} + +func (*TCPReceiveBufferSizeRangeOption) isSettableTransportProtocolOption() {} + +// TCPAvailableCongestionControlOption is the supported congestion control +// algorithms for TCP +type TCPAvailableCongestionControlOption string + +func (*TCPAvailableCongestionControlOption) isGettableTransportProtocolOption() {} + +func (*TCPAvailableCongestionControlOption) isSettableTransportProtocolOption() {} + +// TCPModerateReceiveBufferOption enables/disables receive buffer moderation +// for TCP. +type TCPModerateReceiveBufferOption bool + +func (*TCPModerateReceiveBufferOption) isGettableTransportProtocolOption() {} + +func (*TCPModerateReceiveBufferOption) isSettableTransportProtocolOption() {} + +// GettableSocketOption is a marker interface for socket options that may be +// queried. +type GettableSocketOption interface { + isGettableSocketOption() +} + +// SettableSocketOption is a marker interface for socket options that may be +// configured. +type SettableSocketOption interface { + isSettableSocketOption() +} // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. type BindToDeviceOption NICID +func (*BindToDeviceOption) isGettableSocketOption() {} + +func (*BindToDeviceOption) isSettableSocketOption() {} + // TCPInfoOption is used by GetSockOpt to expose TCP statistics. // // TODO(b/64800844): Add and populate stat fields. @@ -855,68 +980,143 @@ type TCPInfoOption struct { RTTVar time.Duration } +func (*TCPInfoOption) isGettableSocketOption() {} + // KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a // connection must remain idle before the first TCP keepalive packet is sent. // Once this time is reached, KeepaliveIntervalOption is used instead. type KeepaliveIdleOption time.Duration +func (*KeepaliveIdleOption) isGettableSocketOption() {} + +func (*KeepaliveIdleOption) isSettableSocketOption() {} + // KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the // interval between sending TCP keepalive packets. type KeepaliveIntervalOption time.Duration +func (*KeepaliveIntervalOption) isGettableSocketOption() {} + +func (*KeepaliveIntervalOption) isSettableSocketOption() {} + // TCPUserTimeoutOption is used by SetSockOpt/GetSockOpt to specify a user // specified timeout for a given TCP connection. // See: RFC5482 for details. type TCPUserTimeoutOption time.Duration +func (*TCPUserTimeoutOption) isGettableSocketOption() {} + +func (*TCPUserTimeoutOption) isSettableSocketOption() {} + // CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get // the current congestion control algorithm. type CongestionControlOption string -// AvailableCongestionControlOption is used to query the supported congestion -// control algorithms. -type AvailableCongestionControlOption string +func (*CongestionControlOption) isGettableSocketOption() {} -// ModerateReceiveBufferOption is used by buffer moderation. -type ModerateReceiveBufferOption bool +func (*CongestionControlOption) isSettableSocketOption() {} + +func (*CongestionControlOption) isGettableTransportProtocolOption() {} + +func (*CongestionControlOption) isSettableTransportProtocolOption() {} // TCPLingerTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TCP_FIN_WAIT_2 state // before being marked closed. type TCPLingerTimeoutOption time.Duration +func (*TCPLingerTimeoutOption) isGettableSocketOption() {} + +func (*TCPLingerTimeoutOption) isSettableSocketOption() {} + +func (*TCPLingerTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPLingerTimeoutOption) isSettableTransportProtocolOption() {} + // TCPTimeWaitTimeoutOption is used by SetSockOpt/GetSockOpt to set/get the // maximum duration for which a socket lingers in the TIME_WAIT state // before being marked closed. type TCPTimeWaitTimeoutOption time.Duration +func (*TCPTimeWaitTimeoutOption) isGettableSocketOption() {} + +func (*TCPTimeWaitTimeoutOption) isSettableSocketOption() {} + +func (*TCPTimeWaitTimeoutOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitTimeoutOption) isSettableTransportProtocolOption() {} + // TCPDeferAcceptOption is used by SetSockOpt/GetSockOpt to allow a // accept to return a completed connection only when there is data to be // read. This usually means the listening socket will drop the final ACK // for a handshake till the specified timeout until a segment with data arrives. type TCPDeferAcceptOption time.Duration +func (*TCPDeferAcceptOption) isGettableSocketOption() {} + +func (*TCPDeferAcceptOption) isSettableSocketOption() {} + // TCPMinRTOOption is use by SetSockOpt/GetSockOpt to allow overriding // default MinRTO used by the Stack. type TCPMinRTOOption time.Duration +func (*TCPMinRTOOption) isGettableSocketOption() {} + +func (*TCPMinRTOOption) isSettableSocketOption() {} + +func (*TCPMinRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMinRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRTOOption is use by SetSockOpt/GetSockOpt to allow overriding // default MaxRTO used by the Stack. type TCPMaxRTOOption time.Duration +func (*TCPMaxRTOOption) isGettableSocketOption() {} + +func (*TCPMaxRTOOption) isSettableSocketOption() {} + +func (*TCPMaxRTOOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRTOOption) isSettableTransportProtocolOption() {} + // TCPMaxRetriesOption is used by SetSockOpt/GetSockOpt to set/get the // maximum number of retransmits after which we time out the connection. type TCPMaxRetriesOption uint64 +func (*TCPMaxRetriesOption) isGettableSocketOption() {} + +func (*TCPMaxRetriesOption) isSettableSocketOption() {} + +func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {} + // TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify // the number of endpoints that can be in SYN-RCVD state before the stack // switches to using SYN cookies. type TCPSynRcvdCountThresholdOption uint64 +func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {} + +func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {} + +func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {} + // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide // default for number of times SYN is retransmitted before aborting a connect. type TCPSynRetriesOption uint8 +func (*TCPSynRetriesOption) isGettableSocketOption() {} + +func (*TCPSynRetriesOption) isSettableSocketOption() {} + +func (*TCPSynRetriesOption) isGettableTransportProtocolOption() {} + +func (*TCPSynRetriesOption) isSettableTransportProtocolOption() {} + // MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a // default interface for multicast. type MulticastInterfaceOption struct { @@ -924,45 +1124,61 @@ type MulticastInterfaceOption struct { InterfaceAddr Address } -// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to -// AddMembershipOption and RemoveMembershipOption. +func (*MulticastInterfaceOption) isGettableSocketOption() {} + +func (*MulticastInterfaceOption) isSettableSocketOption() {} + +// MembershipOption is used to identify a multicast membership on an interface. type MembershipOption struct { NIC NICID InterfaceAddr Address MulticastAddr Address } -// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast -// group identified by the given multicast address, on the interface matching -// the given interface address. +// AddMembershipOption identifies a multicast group to join on some interface. type AddMembershipOption MembershipOption -// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast -// group identified by the given multicast address, on the interface matching -// the given interface address. +func (*AddMembershipOption) isSettableSocketOption() {} + +// RemoveMembershipOption identifies a multicast group to leave on some +// interface. type RemoveMembershipOption MembershipOption +func (*RemoveMembershipOption) isSettableSocketOption() {} + // OutOfBandInlineOption is used by SetSockOpt/GetSockOpt to specify whether // TCP out-of-band data is delivered along with the normal in-band data. type OutOfBandInlineOption int -// DefaultTTLOption is used by stack.(*Stack).NetworkProtocolOption to specify -// a default TTL. -type DefaultTTLOption uint8 +func (*OutOfBandInlineOption) isGettableSocketOption() {} + +func (*OutOfBandInlineOption) isSettableSocketOption() {} // SocketDetachFilterOption is used by SetSockOpt to detach a previously attached // classic BPF filter on a given endpoint. type SocketDetachFilterOption int +func (*SocketDetachFilterOption) isSettableSocketOption() {} + // OriginalDestinationOption is used to get the original destination address // and port of a redirected packet. type OriginalDestinationOption FullAddress +func (*OriginalDestinationOption) isGettableSocketOption() {} + // TCPTimeWaitReuseOption is used stack.(*Stack).TransportProtocolOption to // specify if the stack can reuse the port bound by an endpoint in TIME-WAIT for // new connections when it is safe from protocol viewpoint. type TCPTimeWaitReuseOption uint8 +func (*TCPTimeWaitReuseOption) isGettableSocketOption() {} + +func (*TCPTimeWaitReuseOption) isSettableSocketOption() {} + +func (*TCPTimeWaitReuseOption) isGettableTransportProtocolOption() {} + +func (*TCPTimeWaitReuseOption) isSettableTransportProtocolOption() {} + const ( // TCPTimeWaitReuseDisabled indicates reuse of port bound by endponts in TIME-WAIT cannot // be reused for new connections. @@ -978,6 +1194,19 @@ const ( TCPTimeWaitReuseLoopbackOnly ) +// LingerOption is used by SetSockOpt/GetSockOpt to set/get the +// duration for which a socket lingers before returning from Close. +// +// +stateify savable +type LingerOption struct { + Enabled bool + Timeout time.Duration +} + +func (*LingerOption) isGettableSocketOption() {} + +func (*LingerOption) isSettableSocketOption() {} + // IPPacketInfo is the message structure for IP_PKTINFO. // // +stateify savable @@ -1020,7 +1249,10 @@ func (r Route) String() string { // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 -// NetworkProtocolNumber is the number of a network protocol. +// NetworkProtocolNumber is the EtherType of a network protocol in an Ethernet +// frame. +// +// See: https://www.iana.org/assignments/ieee-802-numbers/ieee-802-numbers.xhtml type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. @@ -1183,6 +1415,10 @@ type ICMPv6ReceivedPacketStats struct { // Invalid is the total number of ICMPv6 packets received that the // transport layer could not parse. Invalid *StatCounter + + // RouterOnlyPacketsDroppedByHost is the total number of ICMPv6 packets + // dropped due to being router-specific packets. + RouterOnlyPacketsDroppedByHost *StatCounter } // ICMPStats collects ICMP-specific stats (both v4 and v6). @@ -1238,6 +1474,18 @@ type IPStats struct { // MalformedFragmentsReceived is the total number of IP Fragments that were // dropped due to the fragment failing validation checks. MalformedFragmentsReceived *StatCounter + + // IPTablesPreroutingDropped is the total number of IP packets dropped + // in the Prerouting chain. + IPTablesPreroutingDropped *StatCounter + + // IPTablesInputDropped is the total number of IP packets dropped in + // the Input chain. + IPTablesInputDropped *StatCounter + + // IPTablesOutputDropped is the total number of IP packets dropped in + // the Output chain. + IPTablesOutputDropped *StatCounter } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 6d52af98a..06c7a3cd3 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -5,12 +5,16 @@ package(licenses = ["notice"]) go_test( name = "integration_test", size = "small", - srcs = ["multicast_broadcast_test.go"], + srcs = [ + "loopback_test.go", + "multicast_broadcast_test.go", + ], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go new file mode 100644 index 000000000..fecbe7ba7 --- /dev/null +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -0,0 +1,250 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TestLoopbackAcceptAllInSubnet tests that a loopback interface considers +// itself bound to all addresses in the subnet of an assigned address. +func TestLoopbackAcceptAllInSubnet(t *testing.T) { + const ( + nicID = 1 + localPort = 80 + ) + + data := []byte{1, 2, 3, 4} + + ipv4ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + ipv4Bytes := []byte(ipv4Addr.Address) + ipv4Bytes[len(ipv4Bytes)-1]++ + otherIPv4Address := tcpip.Address(ipv4Bytes) + + ipv6ProtocolAddress := tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + } + ipv6Bytes := []byte(ipv6Addr.Address) + ipv6Bytes[len(ipv6Bytes)-1]++ + otherIPv6Address := tcpip.Address(ipv6Bytes) + + tests := []struct { + name string + addAddress tcpip.ProtocolAddress + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool + }{ + { + name: "IPv4 bind to wildcard and send to assigned address", + addAddress: ipv4ProtocolAddress, + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + { + name: "IPv4 bind to wildcard and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + dstAddr: otherIPv4Address, + expectRx: true, + }, + { + name: "IPv4 bind to wildcard send to other address", + addAddress: ipv4ProtocolAddress, + dstAddr: remoteIPv4Addr, + expectRx: false, + }, + { + name: "IPv4 bind to other subnet-local address and send to assigned address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: ipv4Addr.Address, + expectRx: false, + }, + { + name: "IPv4 bind and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: otherIPv4Address, + dstAddr: otherIPv4Address, + expectRx: true, + }, + { + name: "IPv4 bind to assigned address and send to other subnet-local address", + addAddress: ipv4ProtocolAddress, + bindAddr: ipv4Addr.Address, + dstAddr: otherIPv4Address, + expectRx: false, + }, + + { + name: "IPv6 bind and send to assigned address", + addAddress: ipv6ProtocolAddress, + bindAddr: ipv6Addr.Address, + dstAddr: ipv6Addr.Address, + expectRx: true, + }, + { + name: "IPv6 bind to wildcard and send to other subnet-local address", + addAddress: ipv6ProtocolAddress, + dstAddr: otherIPv6Address, + expectRx: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, test.addAddress); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, test.addAddress, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }, + }) + + wq := waiter.Queue{} + rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer rep.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := rep.Bind(bindAddr); err != nil { + t.Fatalf("rep.Bind(%+v): %s", bindAddr, err) + } + + sep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) + } + defer sep.Close() + + wopts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.dstAddr, + Port: localPort, + }, + } + n, _, err := sep.Write(tcpip.SlicePayload(data), wopts) + if err != nil { + t.Fatalf("sep.Write(_, _): %s", err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) + } + + if gotPayload, _, err := rep.Read(nil); test.expectRx { + if err != nil { + t.Fatalf("reep.Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} + +// TestLoopbackSubnetLifetimeBoundToAddr tests that the lifetime of an address +// in a loopback interface's associated subnet is bound to the permanently bound +// address. +func TestLoopbackSubnetLifetimeBoundToAddr(t *testing.T) { + const nicID = 1 + + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: ipv4Addr, + } + addrBytes := []byte(ipv4Addr.Address) + addrBytes[len(addrBytes)-1]++ + otherAddr := tcpip.Address(addrBytes) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID, protoAddr, err) + } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + r, err := s.FindRoute(nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, otherAddr, remoteIPv4Addr, ipv4.ProtocolNumber, err) + } + defer r.Release() + + params := stack.NetworkHeaderParams{ + Protocol: 111, + TTL: 64, + TOS: stack.DefaultTOS, + } + data := buffer.View([]byte{1, 2, 3, 4}) + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, %#v, _): %s", params, err) + } + + // Removing the address should make the endpoint invalid. + if err := s.RemoveAddress(nicID, protoAddr.AddressWithPrefix.Address); err != nil { + t.Fatalf("s.RemoveAddress(%d, %s): %s", nicID, protoAddr.AddressWithPrefix.Address, err) + } + if err := r.WritePacket(nil /* gso */, params, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: data.ToVectorisedView(), + })); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got r.WritePacket(nil, %#v, _) = %s, want = %s", params, err, tcpip.ErrInvalidEndpointState) + } +} diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 9f0dd4d6d..659acbc7a 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -430,7 +431,126 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { } } else { if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} + +// TestReuseAddrAndBroadcast makes sure broadcast packets are received by all +// interested endpoints. +func TestReuseAddrAndBroadcast(t *testing.T) { + const ( + nicID = 1 + localPort = 9000 + loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") + ) + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + + tests := []struct { + name string + broadcastAddr tcpip.Address + }{ + { + name: "Subnet directed broadcast", + broadcastAddr: loopbackBroadcast, + }, + { + name: "IPv4 broadcast", + broadcastAddr: header.IPv4Broadcast, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: "\x7f\x00\x00\x01", + PrefixLen: 8, + }, + } + if err := s.AddProtocolAddress(nicID, protoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, protoAddr, err) + } + + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + // We use the empty subnet instead of just the loopback subnet so we + // also have a route to the IPv4 Broadcast address. + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }, + }) + + // We create endpoints that bind to both the wildcard address and the + // broadcast address to make sure both of these types of "broadcast + // interested" endpoints receive broadcast packets. + wq := waiter.Queue{} + var eps []tcpip.Endpoint + for _, bindWildcard := range []bool{false, true} { + // Create multiple endpoints for each type of "broadcast interested" + // endpoint so we can test that all endpoints receive the broadcast + // packet. + for i := 0; i < 2; i++ { + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) + } + defer ep.Close() + + if err := ep.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.ReuseAddressOption, true): %s", len(eps), err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("eps[%d].SetSockOptBool(tcpip.BroadcastOption, true): %s", len(eps), err) + } + + bindAddr := tcpip.FullAddress{Port: localPort} + if bindWildcard { + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } else { + bindAddr.Addr = test.broadcastAddr + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("eps[%d].Bind(%+v): %s", len(eps), bindAddr, err) + } + } + + eps = append(eps, ep) + } + } + + for i, wep := range eps { + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.broadcastAddr, + Port: localPort, + }, + } + if n, _, err := wep.Write(data, writeOpts); err != nil { + t.Fatalf("eps[%d].Write(_, _): %s", i, err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) + } + + for j, rep := range eps { + if gotPayload, _, err := rep.Read(nil); err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) + } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) + } } } }) diff --git a/pkg/tcpip/time_unsafe.go b/pkg/tcpip/time_unsafe.go index f32d58091..606363567 100644 --- a/pkg/tcpip/time_unsafe.go +++ b/pkg/tcpip/time_unsafe.go @@ -13,7 +13,7 @@ // limitations under the License. // +build go1.9 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index bd6f49eb8..31116309e 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -74,6 +74,8 @@ type endpoint struct { route stack.Route `state:"manual"` ttl uint8 stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -343,10 +345,15 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.SocketDetachFilterOption: +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -415,9 +422,12 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() return nil default: @@ -603,7 +613,7 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -836,3 +846,8 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} + +// LastError implements tcpip.Endpoint.LastError. +func (*endpoint) LastError() *tcpip.Error { + return nil +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 74ef6541e..bb11e4e83 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -109,12 +109,12 @@ func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEnd } // SetOption implements stack.TransportProtocol.SetOption. -func (*protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (*protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 1b03ad6bb..072601d2d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -83,6 +83,8 @@ type endpoint struct { stats tcpip.TransportEndpointStats `state:"nosave"` bound bool boundNIC tcpip.NICID + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // lastErrorMu protects lastError. lastErrorMu sync.Mutex `state:"nosave"` @@ -192,13 +194,13 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes return ep.ReadPacket(addr, nil) } -func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } // Peek implements tcpip.Endpoint.Peek. -func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -210,25 +212,25 @@ func (*endpoint) Disconnect() *tcpip.Error { // Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be // connected, and this function always returnes tcpip.ErrNotSupported. -func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { +func (*endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrNotSupported } // Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used // with Shutdown, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { +func (*endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { return tcpip.ErrNotSupported } // Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with // Listen, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with // Accept, and this function always returns tcpip.ErrNotSupported. -func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -267,12 +269,12 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -297,9 +299,15 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. -func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.SocketDetachFilterOption: +func (ep *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.SocketDetachFilterOption: + return nil + + case *tcpip.LingerOption: + ep.mu.Lock() + ep.linger = *v + ep.mu.Unlock() return nil default: @@ -356,7 +364,7 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } } -func (ep *endpoint) takeLastError() *tcpip.Error { +func (ep *endpoint) LastError() *tcpip.Error { ep.lastErrorMu.Lock() defer ep.lastErrorMu.Unlock() @@ -366,16 +374,21 @@ func (ep *endpoint) takeLastError() *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return ep.takeLastError() +func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + ep.mu.Lock() + *o = ep.linger + ep.mu.Unlock() + return nil + + default: + return tcpip.ErrNotSupported } - return tcpip.ErrNotSupported } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. -func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { +func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return false, tcpip.ErrNotSupported } @@ -512,7 +525,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, } // State implements socket.Socket.State. -func (ep *endpoint) State() uint32 { +func (*endpoint) State() uint32 { return 0 } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index edc2b5b61..e37c00523 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -84,6 +84,8 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -446,12 +448,12 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { } // Listen implements tcpip.Endpoint.Listen. -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (*endpoint) Listen(backlog int) *tcpip.Error { return tcpip.ErrNotSupported } // Accept implements tcpip.Endpoint.Accept. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } @@ -482,12 +484,12 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotSupported } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { +func (*endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { // Even a connected socket doesn't return a remote address. return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -510,9 +512,15 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.SocketDetachFilterOption: +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { + switch v := opt.(type) { + case *tcpip.SocketDetachFilterOption: + return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() return nil default: @@ -577,9 +585,12 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { + switch o := opt.(type) { + case *tcpip.LingerOption: + e.mu.Lock() + *o = e.linger + e.mu.Unlock() return nil default: @@ -739,3 +750,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} + +func (*endpoint) LastError() *tcpip.Error { + return nil +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 234fb95ce..4778e7b1c 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -69,6 +69,7 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/hash/jenkins", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 290172ac9..09d53d158 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -491,7 +491,7 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.takeLastError() + return h.ep.LastError() } } @@ -522,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled SACKEnabled + var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. @@ -620,7 +620,7 @@ func (h *handshake) execute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.takeLastError() + return h.ep.LastError() } case wakerForNewSegment: diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 804e95aea..94207c141 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -86,8 +86,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network // Wait for connection to be established. select { case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -194,8 +193,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network // Wait for connection to be established. select { case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -373,12 +371,12 @@ func testV4Accept(t *testing.T, c *context.Context) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -512,13 +510,13 @@ func TestV6AcceptOnV6(t *testing.T) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - - nep, _, err := c.EP.Accept() + var addr tcpip.FullAddress + nep, _, err := c.EP.Accept(&addr) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(&addr) if err != nil { t.Fatalf("Accept failed: %v", err) } @@ -528,20 +526,14 @@ func TestV6AcceptOnV6(t *testing.T) { } } + if addr.Addr != context.TestV6Addr { + t.Errorf("Unexpected remote address: got %s, want %s", addr.Addr, context.TestV6Addr) + } + // Make sure we can still query the v6 only status of the new endpoint, // that is, that it is in fact a v6 socket. if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil { - t.Fatalf("GetSockOpt failed failed: %v", err) - } - - // Check the peer address. - addr, err := nep.GetRemoteAddress() - if err != nil { - t.Fatalf("GetRemoteAddress failed failed: %v", err) - } - - if addr.Addr != context.TestV6Addr { - t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + t.Errorf("GetSockOptBool(tcpip.V6OnlyOption) failed: %s", err) } } @@ -568,8 +560,9 @@ func TestV4AcceptOnV4(t *testing.T) { func testV4ListenClose(t *testing.T, c *context.Context) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } const n = uint16(32) @@ -612,12 +605,12 @@ func testV4ListenClose(t *testing.T, c *context.Context) { we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - nep, _, err := c.EP.Accept() + nep, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - nep, _, err = c.EP.Accept() + nep, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %v", err) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1ccedebcc..120483838 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -654,6 +654,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // UniqueID implements stack.TransportEndpoint.UniqueID. @@ -849,12 +852,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -864,12 +867,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.cc = cs } - var mrb tcpip.ModerateReceiveBufferOption + var mrb tcpip.TCPModerateReceiveBufferOption if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { e.rcvAutoParams.disabled = !bool(mrb) } - var de DelayEnabled + var de tcpip.TCPDelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -1007,6 +1010,26 @@ func (e *endpoint) Close() { return } + if e.linger.Enabled && e.linger.Timeout == 0 { + s := e.EndpointState() + isResetState := s == StateEstablished || s == StateCloseWait || s == StateFinWait1 || s == StateFinWait2 || s == StateSynRecv + if isResetState { + // Close the endpoint without doing full shutdown and + // send a RST. + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.closeNoShutdownLocked() + + // Wake up worker to close the endpoint. + switch s { + case StateSynRecv: + e.notifyProtocolGoroutine(notifyClose) + default: + e.notifyProtocolGoroutine(notifyTickleWorker) + } + return + } + } + // Issue a shutdown so that the peer knows we won't send any more data // if we're connected, or stop accepting if we're listening. e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) @@ -1211,7 +1234,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -func (e *endpoint) takeLastError() *tcpip.Error { +func (e *endpoint) LastError() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1294,14 +1317,17 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { // indicating the reason why it's not writable. // Caller must hold e.mu and e.sndBufMu func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { - // The endpoint cannot be written to if it's not connected. - if !e.EndpointState().connected() { - switch e.EndpointState() { - case StateError: - return 0, e.HardError - default: - return 0, tcpip.ErrClosedForSend - } + switch s := e.EndpointState(); { + case s == StateError: + return 0, e.HardError + case !s.connecting() && !s.connected(): + return 0, tcpip.ErrClosedForSend + case s.connecting(): + // As per RFC793, page 56, a send request arriving when in connecting + // state, can be queued to be completed after the state becomes + // connected. Return an error code for the caller of endpoint Write to + // try again, until the connection handshake is complete. + return 0, tcpip.ErrWouldBlock } // Check if the connection has already been closed for sends. @@ -1609,7 +1635,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min { v = rs.Min @@ -1659,7 +1685,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if v < ss.Min { v = ss.Min @@ -1699,7 +1725,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -1713,10 +1739,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case tcpip.BindToDeviceOption: - id := tcpip.NICID(v) + case *tcpip.BindToDeviceOption: + id := tcpip.NICID(*v) if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } @@ -1724,40 +1750,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = id e.UnlockUser() - case tcpip.KeepaliveIdleOption: + case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() - e.keepalive.idle = time.Duration(v) + e.keepalive.idle = time.Duration(*v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case tcpip.KeepaliveIntervalOption: + case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() - e.keepalive.interval = time.Duration(v) + e.keepalive.interval = time.Duration(*v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case tcpip.OutOfBandInlineOption: + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. - case tcpip.TCPUserTimeoutOption: + case *tcpip.TCPUserTimeoutOption: e.LockUser() - e.userTimeout = time.Duration(v) + e.userTimeout = time.Duration(*v) e.UnlockUser() - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and // validate that the specified algorithm is actually // supported in the stack. - var avail tcpip.AvailableCongestionControlOption + var avail tcpip.TCPAvailableCongestionControlOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { return err } availCC := strings.Split(string(avail), " ") for _, cc := range availCC { - if v == tcpip.CongestionControlOption(cc) { + if *v == tcpip.CongestionControlOption(cc) { e.LockUser() state := e.EndpointState() - e.cc = v + e.cc = *v switch state { case StateEstablished: if e.EndpointState() == state { @@ -1773,31 +1799,45 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.TCPLingerTimeoutOption: + case *tcpip.TCPLingerTimeoutOption: e.LockUser() - if v < 0 { + + switch { + case *v < 0: // Same as effectively disabling TCPLinger timeout. - v = 0 - } - // Cap it to MaxTCPLingerTimeout. - stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) - if v > stkTCPLingerTimeout { - v = stkTCPLingerTimeout + *v = -1 + case *v == 0: + // Same as the stack default. + var stackLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil { + panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err)) + } + *v = stackLingerTimeout + case *v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout): + // Cap it to Stack's default TCP_LINGER2 timeout. + *v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) + default: } - e.tcpLingerTimeout = time.Duration(v) + + e.tcpLingerTimeout = time.Duration(*v) e.UnlockUser() - case tcpip.TCPDeferAcceptOption: + case *tcpip.TCPDeferAcceptOption: e.LockUser() - if time.Duration(v) > MaxRTO { - v = tcpip.TCPDeferAcceptOption(MaxRTO) + if time.Duration(*v) > MaxRTO { + *v = tcpip.TCPDeferAcceptOption(MaxRTO) } - e.deferAccept = time.Duration(v) + e.deferAccept = time.Duration(*v) e.UnlockUser() - case tcpip.SocketDetachFilterOption: + case *tcpip.SocketDetachFilterOption: return nil + case *tcpip.LingerOption: + e.LockUser() + e.linger = *v + e.UnlockUser() + default: return nil } @@ -1956,11 +1996,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case tcpip.ErrorOption: - return e.takeLastError() - case *tcpip.BindToDeviceOption: e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) @@ -2013,8 +2050,10 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.UnlockUser() case *tcpip.OriginalDestinationOption: + e.LockUser() ipt := e.stack.IPTables() addr, port, err := ipt.OriginalDst(e.ID) + e.UnlockUser() if err != nil { return err } @@ -2023,6 +2062,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { Port: port, } + case *tcpip.LingerOption: + e.LockUser() + *o = e.linger + e.UnlockUser() + default: return tcpip.ErrUnknownProtocolOption } @@ -2169,7 +2213,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { if err != tcpip.ErrPortInUse || !reuse { return false, nil } @@ -2207,7 +2251,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr, nil /* testPort */); err != nil { return false, nil } } @@ -2447,7 +2491,9 @@ func (e *endpoint) startAcceptedLoop() { // Accept returns a new endpoint if a peer has established a connection // to an endpoint previously set to listen mode. -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +// +// addr if not-nil will contain the peer address of the returned endpoint. +func (e *endpoint) Accept(peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -2469,6 +2515,9 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { default: return nil, nil, tcpip.ErrWouldBlock } + if peerAddr != nil { + *peerAddr = n.getRemoteAddress() + } return n, n.waiterQueue, nil } @@ -2505,47 +2554,45 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) - if err != nil { - return err - } - - e.boundBindToDevice = e.bindToDevice - e.boundPortFlags = e.portFlags - e.isPortReserved = true - e.effectiveNetProtos = netProtos - e.ID.LocalPort = port - - // Any failures beyond this point must remove the port registration. - defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { - if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) - e.isPortReserved = false - e.effectiveNetProtos = nil - e.ID.LocalPort = 0 - e.ID.LocalAddress = "" - e.boundNICID = 0 - e.boundBindToDevice = 0 - e.boundPortFlags = ports.Flags{} - } - }(e.boundPortFlags, e.boundBindToDevice) - + var nic tcpip.NICID // If an address is specified, we must ensure that it's one of our // local addresses. if len(addr.Addr) != 0 { - nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + nic = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) if nic == 0 { return tcpip.ErrBadLocalAddress } - - e.boundNICID = nic e.ID.LocalAddress = addr.Addr } - if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + id := e.ID + id.LocalPort = p + // CheckRegisterTransportEndpoint should only return an error if there is a + // listening endpoint bound with the same id and portFlags and bindToDevice + // options. + // + // NOTE: Only listening and connected endpoint register with + // demuxer. Further connected endpoints always have a remote + // address/port. Hence this will only return an error if there is a matching + // listening endpoint. + if err := e.stack.CheckRegisterTransportEndpoint(nic, netProtos, ProtocolNumber, id, e.portFlags, e.bindToDevice); err != nil { + return false + } + return true + }) + if err != nil { return err } + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + // TODO(gvisor.dev/issue/3691): Add test to verify boundNICID is correct. + e.boundNICID = nic + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.ID.LocalPort = port + // Mark endpoint as bound. e.setEndpointState(StateBound) @@ -2573,11 +2620,15 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { return tcpip.FullAddress{}, tcpip.ErrNotConnected } + return e.getRemoteAddress(), nil +} + +func (e *endpoint) getRemoteAddress() tcpip.FullAddress { return tcpip.FullAddress{ Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort, NIC: e.boundNICID, - }, nil + } } func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -2694,7 +2745,7 @@ func (e *endpoint) receiveBufferSize() int { } func (e *endpoint) maxReceiveBufferSize() int { - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2774,7 +2825,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v SACKEnabled + var v tcpip.TCPSACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 723e47ddc..41d0050f3 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -182,14 +182,14 @@ func (e *endpoint) Resume(s *stack.Stack) { epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss SendBufferSizeOption + var ss tcpip.TCPSendBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } } - var rs ReceiveBufferSizeOption + var rs tcpip.TCPReceiveBufferSizeRangeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index c5afa2680..74a17af79 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -29,6 +29,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -79,50 +80,6 @@ const ( ccCubic = "cubic" ) -// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to -// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -// Recovery is used by stack.(*Stack).TransportProtocolOption to -// set loss detection algorithm in TCP. -type Recovery int32 - -const ( - // RACKLossDetection indicates RACK is used for loss detection and - // recovery. - RACKLossDetection Recovery = 1 << iota - - // RACKStaticReoWnd indicates the reordering window should not be - // adjusted when DSACK is received. - RACKStaticReoWnd - - // RACKNoDupTh indicates RACK should not consider the classic three - // duplicate acknowledgements rule to mark the segments as lost. This - // is used when reordering is not detected. - RACKNoDupTh -) - -// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to -// enable/disable Nagle's algorithm in TCP. -type DelayEnabled bool - -// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption -// to get/set the default, min and max TCP send buffer sizes. -type SendBufferSizeOption struct { - Min int - Default int - Max int -} - -// ReceiveBufferSizeOption is used by -// stack.(Stack*).TransportProtocolOption to get/set the default, min and max -// TCP receive buffer sizes. -type ReceiveBufferSizeOption struct { - Min int - Default int - Max int -} - // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -183,10 +140,10 @@ func (s *synRcvdCounter) Threshold() uint64 { type protocol struct { mu sync.RWMutex sackEnabled bool - recovery Recovery + recovery tcpip.TCPRecovery delayEnabled bool - sendBufferSize SendBufferSizeOption - recvBufferSize ReceiveBufferSizeOption + sendBufferSize tcpip.TCPSendBufferSizeRangeOption + recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -296,49 +253,49 @@ func replyWithReset(s *segment, tos, ttl uint8) { } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.Lock() - p.sackEnabled = bool(v) + p.sackEnabled = bool(*v) p.mu.Unlock() return nil - case Recovery: + case *tcpip.TCPRecovery: p.mu.Lock() - p.recovery = Recovery(v) + p.recovery = *v p.mu.Unlock() return nil - case DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.Lock() - p.delayEnabled = bool(v) + p.delayEnabled = bool(*v) p.mu.Unlock() return nil - case SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.sendBufferSize = v + p.sendBufferSize = *v p.mu.Unlock() return nil - case ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.recvBufferSize = v + p.recvBufferSize = *v p.mu.Unlock() return nil - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { - if string(v) == c { + if string(*v) == c { p.mu.Lock() - p.congestionControl = string(v) + p.congestionControl = string(*v) p.mu.Unlock() return nil } @@ -347,75 +304,79 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // is specified. return tcpip.ErrNoSuchFile - case tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.Lock() - p.moderateReceiveBuffer = bool(v) + p.moderateReceiveBuffer = bool(*v) p.mu.Unlock() return nil - case tcpip.TCPLingerTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPLingerTimeoutOption: p.mu.Lock() - p.lingerTimeout = time.Duration(v) + if *v < 0 { + p.lingerTimeout = 0 + } else { + p.lingerTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitTimeoutOption: - if v < 0 { - v = 0 - } + case *tcpip.TCPTimeWaitTimeoutOption: p.mu.Lock() - p.timeWaitTimeout = time.Duration(v) + if *v < 0 { + p.timeWaitTimeout = 0 + } else { + p.timeWaitTimeout = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPTimeWaitReuseOption: - if v < tcpip.TCPTimeWaitReuseDisabled || v > tcpip.TCPTimeWaitReuseLoopbackOnly { + case *tcpip.TCPTimeWaitReuseOption: + if *v < tcpip.TCPTimeWaitReuseDisabled || *v > tcpip.TCPTimeWaitReuseLoopbackOnly { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.timeWaitReuse = v + p.timeWaitReuse = *v p.mu.Unlock() return nil - case tcpip.TCPMinRTOOption: - if v < 0 { - v = tcpip.TCPMinRTOOption(MinRTO) - } + case *tcpip.TCPMinRTOOption: p.mu.Lock() - p.minRTO = time.Duration(v) + if *v < 0 { + p.minRTO = MinRTO + } else { + p.minRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRTOOption: - if v < 0 { - v = tcpip.TCPMaxRTOOption(MaxRTO) - } + case *tcpip.TCPMaxRTOOption: p.mu.Lock() - p.maxRTO = time.Duration(v) + if *v < 0 { + p.maxRTO = MaxRTO + } else { + p.maxRTO = time.Duration(*v) + } p.mu.Unlock() return nil - case tcpip.TCPMaxRetriesOption: + case *tcpip.TCPMaxRetriesOption: p.mu.Lock() - p.maxRetries = uint32(v) + p.maxRetries = uint32(*v) p.mu.Unlock() return nil - case tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPSynRcvdCountThresholdOption: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(v)) + p.synRcvdCount.SetThreshold(uint64(*v)) p.mu.Unlock() return nil - case tcpip.TCPSynRetriesOption: - if v < 1 || v > 255 { + case *tcpip.TCPSynRetriesOption: + if *v < 1 || *v > 255 { return tcpip.ErrInvalidOptionValue } p.mu.Lock() - p.synRetries = uint8(v) + p.synRetries = uint8(*v) p.mu.Unlock() return nil @@ -425,33 +386,33 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) *tcpip.Error { switch v := option.(type) { - case *SACKEnabled: + case *tcpip.TCPSACKEnabled: p.mu.RLock() - *v = SACKEnabled(p.sackEnabled) + *v = tcpip.TCPSACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *Recovery: + case *tcpip.TCPRecovery: p.mu.RLock() - *v = Recovery(p.recovery) + *v = tcpip.TCPRecovery(p.recovery) p.mu.RUnlock() return nil - case *DelayEnabled: + case *tcpip.TCPDelayEnabled: p.mu.RLock() - *v = DelayEnabled(p.delayEnabled) + *v = tcpip.TCPDelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *SendBufferSizeOption: + case *tcpip.TCPSendBufferSizeRangeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *ReceiveBufferSizeOption: + case *tcpip.TCPReceiveBufferSizeRangeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -463,15 +424,15 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil - case *tcpip.AvailableCongestionControlOption: + case *tcpip.TCPAvailableCongestionControlOption: p.mu.RLock() - *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.TCPAvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.RUnlock() return nil - case *tcpip.ModerateReceiveBufferOption: + case *tcpip.TCPModerateReceiveBufferOption: p.mu.RLock() - *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + *v = tcpip.TCPModerateReceiveBufferOption(p.moderateReceiveBuffer) p.mu.RUnlock() return nil @@ -546,33 +507,18 @@ func (p *protocol) SynRcvdCounter() *synRcvdCounter { // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - // TCP header is variable length, peek at it first. - hdrLen := header.TCPMinimumSize - hdr, ok := pkt.Data.PullUp(hdrLen) - if !ok { - return false - } - - // If the header has options, pull those up as well. - if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { - // TODO(gvisor.dev/issue/2404): Figure out whether to reject this kind of - // packets. - hdrLen = offset - } - - _, ok = pkt.TransportHeader().Consume(hdrLen) - return ok + return parse.TCP(pkt) } // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { p := protocol{ - sendBufferSize: SendBufferSizeOption{ + sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: ReceiveBufferSizeOption{ + recvBufferSize: tcpip.TCPReceiveBufferSizeRangeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, @@ -587,7 +533,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, - recovery: RACKLossDetection, + recovery: tcpip.TCPRACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 5e0bfe585..cfd43b5e3 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -268,14 +268,7 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo // If we are in one of the shutdown states then we need to do // additional checks before we try and process the segment. switch state { - case StateCloseWait: - // If the ACK acks something not yet sent then we send an ACK. - if r.ep.snd.sndNxt.LessThan(s.ackNumber) { - r.ep.snd.sendAck() - return true, nil - } - fallthrough - case StateClosing, StateLastAck: + case StateCloseWait, StateClosing, StateLastAck: if !s.sequenceNumber.LessThanEq(r.rcvNxt) { // Just drop the segment as we have // already received a FIN and this @@ -284,9 +277,31 @@ func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, clo return true, nil } fallthrough - case StateFinWait1: - fallthrough - case StateFinWait2: + case StateFinWait1, StateFinWait2: + // If the ACK acks something not yet sent then we send an ACK. + // + // RFC793, page 37: If the connection is in a synchronized state, + // (ESTABLISHED, FIN-WAIT-1, FIN-WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, + // TIME-WAIT), any unacceptable segment (out of window sequence number + // or unacceptable acknowledgment number) must elicit only an empty + // acknowledgment segment containing the current send-sequence number + // and an acknowledgment indicating the next sequence number expected + // to be received, and the connection remains in the same state. + // + // Just as on Linux, we do not apply this behavior when state is + // ESTABLISHED. + // Linux receive processing for all states except ESTABLISHED and + // TIME_WAIT is here where if the ACK check fails, we attempt to + // reply back with an ACK with correct seq/ack numbers. + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L6186 + // The ESTABLISHED state processing is here where if the ACK check + // fails, we ignore the packet: + // https://github.com/torvalds/linux/blob/v5.8/net/ipv4/tcp_input.c#L5591 + if r.ep.snd.sndNxt.LessThan(s.ackNumber) { + r.ep.snd.sendAck() + return true, nil + } + // If we are closed for reads (either due to an // incoming FIN or the user calling shutdown(.., // SHUT_RD) then any data past the rcvNxt should @@ -421,6 +436,13 @@ func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn // Just silently drop any RST packets in TIME_WAIT. We do not support // TIME_WAIT assasination as a result we confirm w/ fix 1 as described // in https://tools.ietf.org/html/rfc1337#section-3. + // + // This behavior overrides RFC793 page 70 where we transition to CLOSED + // on receiving RST, which is also default Linux behavior. + // On Linux the RST can be ignored by setting sysctl net.ipv4.tcp_rfc1337. + // + // As we do not yet support PAWS, we are being conservative in ignoring + // RSTs by default. if s.flagIsSet(header.TCPFlagRst) { return false, false } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 99521f0c1..ef7f5719f 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,9 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) + opt := tcpip.TCPSACKEnabled(enable) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -162,8 +163,9 @@ func TestSackPermittedAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -236,8 +238,9 @@ func TestSackDisabledAccept(t *testing.T) { // Set the SynRcvd threshold to // zero to force a syn cookie // based accept to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 55ae09a2f..b1e5f1b24 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -74,8 +74,8 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted) + if err := ep.LastError(); err != tcpip.ErrAborted { + t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted) } // Call Connect again to retreive the handshake failure status @@ -291,12 +291,12 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -309,8 +309,8 @@ func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) { // Lower stackwide TIME_WAIT timeout so that the reservations // are released instantly on Close. tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil { - t.Fatalf("e.stack.SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTW); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, tcpTW, tcpTW, err) } c.EP.Close() @@ -432,8 +432,9 @@ func TestConnectResetAfterClose(t *testing.T) { // Set TCPLinger to 3 seconds so that sockets are marked closed // after 3 second in FIN_WAIT2 state. tcpLingerTimeout := 3 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s) failed: %s", tcpLingerTimeout, err) + opt := tcpip.TCPLingerTimeoutOption(tcpLingerTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -506,8 +507,9 @@ func TestCurrentConnectedIncrement(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -933,8 +935,8 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { // Set the SynRcvd threshold to force a syn cookie based accept to happen. opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, %#v): %s", tcp.ProtocolNumber, opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { @@ -1349,7 +1351,9 @@ func TestConnectBindToDevice(t *testing.T) { c.Create(-1) bindToDevice := tcpip.BindToDeviceOption(test.device) - c.EP.SetSockOpt(bindToDevice) + if err := c.EP.SetSockOpt(&bindToDevice); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err) + } // Start connection attempt. waitEntry, _ := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) @@ -2201,12 +2205,12 @@ func TestScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2275,12 +2279,12 @@ func TestNonScaledWindowAccept(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2838,12 +2842,12 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -2865,8 +2869,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { // Set the SynRcvd threshold to zero to force a syn cookie based accept // to happen. - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(0) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -2893,12 +2898,12 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -3023,8 +3028,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Wait for connection to be established. select { case <-ch: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + if err := c.EP.LastError(); err != nil { + t.Fatalf("Connect failed: %s", err) } case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -3144,8 +3149,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { defer c.Cleanup() const numRetries = 2 - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { - t.Fatalf("could not set protocol option MaxRetries.\n") + opt := tcpip.TCPMaxRetriesOption(numRetries) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3204,8 +3210,9 @@ func TestMaxRTO(t *testing.T) { defer c.Cleanup() rto := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d) failed: %s", rto, err) + opt := tcpip.TCPMaxRTOOption(rto) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) @@ -3962,8 +3969,9 @@ func TestReadAfterClosedState(t *testing.T) { // Set TCPTimeWaitTimeout to 1 seconds so that sockets are marked closed // after 1 second in TIME_WAIT state. tcpTimeWaitTimeout := 1 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } c.CreateConnected(789, 30000, -1 /* epRcvBuf */) @@ -4202,11 +4210,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ - Min: 1, - Default: tcp.DefaultSendBufferSize * 2, - Max: tcp.DefaultSendBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4219,11 +4231,15 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ - Min: 1, - Default: tcp.DefaultReceiveBufferSize * 3, - Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %v", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30, + } + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } ep.Close() @@ -4250,12 +4266,18 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPSendBufferSizeRangeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Set values below the min. @@ -4321,16 +4343,15 @@ func TestBindToDeviceOption(t *testing.T) { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr) + if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } bindToDevice := tcpip.BindToDeviceOption(88888) if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %s, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) + t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) + } else if bindToDevice != testAction.getBindToDevice { + t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) } }) } @@ -4411,7 +4432,7 @@ func TestSelfConnect(t *testing.T) { } <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { + if err := ep.LastError(); err != nil { t.Fatalf("Connect failed: %s", err) } @@ -4717,8 +4738,8 @@ func TestStackSetCongestionControl(t *testing.T) { t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = %s, want = %s", tcp.ProtocolNumber, tc.cc, tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption @@ -4750,12 +4771,12 @@ func TestStackAvailableCongestionControl(t *testing.T) { s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcpip.AvailableCongestionControlOption + var aCC tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption: %v, want: %v", got, want) } } @@ -4766,18 +4787,18 @@ func TestStackSetAvailableCongestionControl(t *testing.T) { s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcpip.AvailableCongestionControlOption("xyz") + aCC := tcpip.TCPAvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%s)) = nil, want non-nil", tcp.ProtocolNumber, aCC, aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcpip.AvailableCongestionControlOption + var cc tcpip.TCPAvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + t.Fatalf("s.TransportProtocolOptio(%d, &%T(%s)): %s", tcp.ProtocolNumber, cc, cc, err) } - if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.TCPAvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.TCPAvailableCongestionControlOption = %s, want = %s", got, want) } } @@ -4806,20 +4827,20 @@ func TestEndpointSetCongestionControl(t *testing.T) { var oldCC tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err) + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) } if connected { c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) } - if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err) + if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { + t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err) + t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) } got, want := cc, oldCC @@ -4831,7 +4852,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { want = tc.cc } if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) + t.Fatalf("got congestion control = %+v, want = %+v", got, want) } }) } @@ -4841,8 +4862,8 @@ func TestEndpointSetCongestionControl(t *testing.T) { func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() opt := tcpip.CongestionControlOption("cubic") - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %s = %s", opt, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)) %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -4852,11 +4873,23 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) + } + keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) + } c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) - c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { + t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) + } // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -5122,12 +5155,12 @@ func TestListenBacklogFull(t *testing.T) { defer c.WQ.EventUnregister(&we) for i := 0; i < listenBacklog; i++ { - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5139,7 +5172,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5151,12 +5184,12 @@ func TestListenBacklogFull(t *testing.T) { // Now a new handshake must succeed. executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5181,6 +5214,8 @@ func TestListenBacklogFull(t *testing.T) { func TestListenNoAcceptNonUnicastV4(t *testing.T) { multicastAddr := tcpip.Address("\xe0\x00\x01\x02") otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03") + subnet := context.StackAddrWithPrefix.Subnet() + subnetBroadcastAddr := subnet.Broadcast() tests := []struct { name string @@ -5188,53 +5223,59 @@ func TestListenNoAcceptNonUnicastV4(t *testing.T) { dstAddr tcpip.Address }{ { - "SourceUnspecified", - header.IPv4Any, - context.StackAddr, + name: "SourceUnspecified", + srcAddr: header.IPv4Any, + dstAddr: context.StackAddr, }, { - "SourceBroadcast", - header.IPv4Broadcast, - context.StackAddr, + name: "SourceBroadcast", + srcAddr: header.IPv4Broadcast, + dstAddr: context.StackAddr, }, { - "SourceOurMulticast", - multicastAddr, - context.StackAddr, + name: "SourceOurMulticast", + srcAddr: multicastAddr, + dstAddr: context.StackAddr, }, { - "SourceOtherMulticast", - otherMulticastAddr, - context.StackAddr, + name: "SourceOtherMulticast", + srcAddr: otherMulticastAddr, + dstAddr: context.StackAddr, }, { - "DestUnspecified", - context.TestAddr, - header.IPv4Any, + name: "DestUnspecified", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Any, }, { - "DestBroadcast", - context.TestAddr, - header.IPv4Broadcast, + name: "DestBroadcast", + srcAddr: context.TestAddr, + dstAddr: header.IPv4Broadcast, }, { - "DestOurMulticast", - context.TestAddr, - multicastAddr, + name: "DestOurMulticast", + srcAddr: context.TestAddr, + dstAddr: multicastAddr, }, { - "DestOtherMulticast", - context.TestAddr, - otherMulticastAddr, + name: "DestOtherMulticast", + srcAddr: context.TestAddr, + dstAddr: otherMulticastAddr, + }, + { + name: "SrcSubnetBroadcast", + srcAddr: subnetBroadcastAddr, + dstAddr: context.StackAddr, + }, + { + name: "DestSubnetBroadcast", + srcAddr: context.TestAddr, + dstAddr: subnetBroadcastAddr, }, } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5334,11 +5375,7 @@ func TestListenNoAcceptNonUnicastV6(t *testing.T) { } for _, test := range tests { - test := test // capture range variable - t.Run(test.name, func(t *testing.T) { - t.Parallel() - c := context.New(t, defaultMTU) defer c.Cleanup() @@ -5463,12 +5500,12 @@ func TestListenSynRcvdQueueFull(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5492,8 +5529,9 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err) + opt := tcpip.TCPSynRcvdCountThresholdOption(1) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create TCP endpoint. @@ -5539,12 +5577,12 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5555,7 +5593,7 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { } // Now verify that there are no more connections that can be accepted. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != tcpip.ErrWouldBlock { select { case <-ch: @@ -5644,7 +5682,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { RcvWnd: 30000, }) - newEP, _, err := c.EP.Accept() + newEP, _, err := c.EP.Accept(nil) if err != nil && err != tcpip.ErrWouldBlock { t.Fatalf("Accept failed: %s", err) @@ -5659,7 +5697,7 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Wait for connection to be established. select { case <-ch: - newEP, _, err = c.EP.Accept() + newEP, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5717,12 +5755,12 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Verify that there is only one acceptable connection at this point. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5787,12 +5825,12 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { defer c.WQ.EventUnregister(&we) // Now check that there is one acceptable connections. - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - _, _, err = c.EP.Accept() + _, _, err = c.EP.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5840,12 +5878,12 @@ func TestEndpointBindListenAcceptState(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - aep, _, err := ep.Accept() + aep, _, err := ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - aep, _, err = ep.Accept() + aep, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -5893,13 +5931,19 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size defined above. @@ -6014,13 +6058,19 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize} + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v): %s", tcp.ProtocolNumber, opt, err) + } } // Enable auto-tuning. - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Change the expected window scale to match the value needed for the // maximum buffer size used by stack. @@ -6156,7 +6206,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcp.DelayEnabled + delayEnabled tcpip.TCPDelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -6164,17 +6214,17 @@ func TestDelayEnabled(t *testing.T) { } { c := context.New(t, defaultMTU) defer c.Cleanup() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { - t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, v.delayEnabled, v.delayEnabled, err) } checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.TCPDelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcp.DelayEnabled + var gotDelayEnabled tcpip.TCPDelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } @@ -6206,24 +6256,27 @@ func TestTCPLingerTimeout(t *testing.T) { tcpLingerTimeout time.Duration want time.Duration }{ - {"NegativeLingerTimeout", -123123, 0}, - {"ZeroLingerTimeout", 0, 0}, + {"NegativeLingerTimeout", -123123, -1}, + // Zero is treated same as the stack's default TCP_LINGER2 timeout. + {"ZeroLingerTimeout", 0, tcp.DefaultTCPLingerTimeout}, {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, // Values > stack's TCPLingerTimeout are capped to the stack's // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second}, + {"AboveMaxLingerTimeout", tcp.MaxTCPLingerTimeout + 5*time.Second, tcp.MaxTCPLingerTimeout}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { - t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) } - var v tcpip.TCPLingerTimeoutOption + + v = 0 if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + t.Fatalf("GetSockOpt(&%T) = %s", v, err) } if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + t.Fatalf("got linger timeout = %s, want = %s", got, want) } }) } @@ -6277,12 +6330,12 @@ func TestTCPTimeWaitRSTIgnored(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6396,12 +6449,12 @@ func TestTCPTimeWaitOutOfOrder(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6503,12 +6556,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6586,12 +6639,12 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.SendPacket(nil, ackHeaders) // Try to accept the connection. - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6609,8 +6662,9 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 @@ -6659,12 +6713,12 @@ func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6759,8 +6813,9 @@ func TestTCPCloseWithData(t *testing.T) { // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed // after 5 seconds in TIME_WAIT state. tcpTimeWaitTimeout := 5 * time.Second - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { - t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + opt := tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, tcpTimeWaitTimeout, err) } wq := &waiter.Queue{} @@ -6808,12 +6863,12 @@ func TestTCPCloseWithData(t *testing.T) { wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { t.Fatalf("Accept failed: %s", err) } @@ -6940,7 +6995,10 @@ func TestTCPUserTimeout(t *testing.T) { // expired. initRTO := 1 * time.Second userTimeout := initRTO / 2 - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + v := tcpip.TCPUserTimeoutOption(userTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) + } // Send some data and wait before ACKing it. view := buffer.NewView(3) @@ -7014,18 +7072,31 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) - c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) + } + keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) + } + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { + t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) + } // Set userTimeout to be the duration to be 1 keepalive // probes. Which means that after the first probe is sent // the second one should cause the connection to be // closed due to userTimeout being hit. - userTimeout := 1 * keepAliveInterval - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&userTimeout); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) + } // Check that the connection is still alive. if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { @@ -7232,14 +7303,15 @@ func TestTCPDeferAccept(t *testing.T) { } const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Send data. This should result in an acceptable endpoint. @@ -7260,9 +7332,9 @@ func TestTCPDeferAccept(t *testing.T) { // Give a bit of time for the socket to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7289,14 +7361,15 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { } const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) - if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: %s", err, tcpip.ErrWouldBlock) + if _, _, err := c.EP.Accept(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Accept(nil) = %s, want: %s", err, tcpip.ErrWouldBlock) } // Sleep for a little of the tcpDeferAccept timeout. @@ -7328,9 +7401,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { // Give sometime for the endpoint to be delivered to the accept queue. time.Sleep(50 * time.Millisecond) - aep, _, err := c.EP.Accept() + aep, _, err := c.EP.Accept(nil) if err != nil { - t.Fatalf("c.EP.Accept() returned unexpected error got: %s, want: nil", err) + t.Fatalf("got c.EP.Accept(nil) = %s, want: nil", err) } aep.Close() @@ -7428,9 +7501,10 @@ func TestSetStackTimeWaitReuse(t *testing.T) { } for _, tc := range testCases { - err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitReuseOption(tc.v)) + opt := tcpip.TCPTimeWaitReuseOption(tc.v) + err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt) if got, want := err, tc.err; got != want { - t.Fatalf("s.TransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.v, err, tc.err) + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)) = %s, want = %s", tcp.ProtocolNumber, tc.v, tc.v, err, tc.err) } if tc.err != nil { continue diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 8edbff964..44593ed98 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -131,8 +131,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -192,8 +193,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { - t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + var opt tcpip.TCPSynRcvdCountThresholdOption + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b6031354e..85e8c1c75 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -53,11 +53,11 @@ const ( TestPort = 4096 // StackV6Addr is the IPv6 address assigned to the stack. - StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + StackV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" // TestV6Addr is the source address for packets sent to the stack via // the link layer endpoint. - TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + TestV6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // StackV4MappedAddr is StackAddr as a mapped v6 address. StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr @@ -73,6 +73,18 @@ const ( testInitialSequenceNumber = 789 ) +// StackAddrWithPrefix is StackAddr with its associated prefix length. +var StackAddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackAddr, + PrefixLen: 24, +} + +// StackV6AddrWithPrefix is StackV6Addr with its associated prefix length. +var StackV6AddrWithPrefix = tcpip.AddressWithPrefix{ + Address: StackV6Addr, + PrefixLen: header.IIDOffsetInIPv6Address * 8, +} + // Headers is used to represent the TCP header fields when building a // new packet. type Headers struct { @@ -146,19 +158,22 @@ func New(t *testing.T, mtu uint32) *Context { const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + sendBufOpt := tcpip.TCPSendBufferSizeRangeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &sendBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, sendBufOpt, err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { - t.Fatalf("SetTransportProtocolOption failed: %s", err) + rcvBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize} + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &rcvBufOpt); err != nil { + t.Fatalf("SetTransportProtocolOption(%d, &%#v) failed: %s", tcp.ProtocolNumber, rcvBufOpt, err) } // Increase minimum RTO in tests to avoid test flakes due to early // retransmit in case the test executors are overloaded and cause timers // to fire earlier than expected. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil { - t.Fatalf("failed to set stack-wide minRTO: %s", err) + minRTOOpt := tcpip.TCPMinRTOOption(3 * time.Second) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &minRTOOpt); err != nil { + t.Fatalf("s.SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, minRTOOpt, minRTOOpt, err) } // Some of the congestion control tests send up to 640 packets, we so @@ -181,12 +196,20 @@ func New(t *testing.T, mtu uint32) *Context { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) } - if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) } s.SetRouteTable([]tcpip.Route{ @@ -638,7 +661,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // Wait for connection to be established. select { case <-notifyCh: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { + if err := c.EP.LastError(); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -882,8 +905,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Wait for connection to be established. select { case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -949,12 +971,12 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption wq.EventRegister(&we, waiter.EventIn) defer wq.EventUnregister(&we) - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err == tcpip.ErrWouldBlock { // Wait for connection to be established. select { case <-ch: - c.EP, _, err = ep.Accept() + c.EP, _, err = ep.Accept(nil) if err != nil { c.t.Fatalf("Accept failed: %v", err) } @@ -1097,7 +1119,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcp.SACKEnabled + var v tcpip.TCPSACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index b5d2d0ba6..c78549424 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -32,6 +32,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/ports", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 73608783c..518f636f0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -139,7 +139,7 @@ type endpoint struct { // multicastMemberships that need to be remvoed when the endpoint is // closed. Protected by the mu mutex. - multicastMemberships []multicastMembership + multicastMemberships map[multicastMembership]struct{} // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -154,6 +154,9 @@ type endpoint struct { // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner + + // linger is used for SO_LINGER socket option. + linger tcpip.LingerOption } // +stateify savable @@ -182,12 +185,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // TTL=1. // // Linux defaults to TTL=1. - multicastTTL: 1, - multicastLoop: true, - rcvBufSizeMax: 32 * 1024, - sndBufSizeMax: 32 * 1024, - state: StateInitial, - uniqueID: s.UniqueID(), + multicastTTL: 1, + multicastLoop: true, + rcvBufSizeMax: 32 * 1024, + sndBufSizeMax: 32 * 1024, + multicastMemberships: make(map[multicastMembership]struct{}), + state: StateInitial, + uniqueID: s.UniqueID(), } // Override with stack defaults. @@ -209,7 +213,7 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } -func (e *endpoint) takeLastError() *tcpip.Error { +func (e *endpoint) LastError() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() @@ -237,10 +241,10 @@ func (e *endpoint) Close() { e.boundPortFlags = ports.Flags{} } - for _, mem := range e.multicastMemberships { + for mem := range e.multicastMemberships { e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) } - e.multicastMemberships = nil + e.multicastMemberships = make(map[multicastMembership]struct{}) // Close the receive list and drain it. e.rcvMu.Lock() @@ -268,7 +272,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {} // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if err := e.takeLastError(); err != nil { + if err := e.LastError(); err != nil { return buffer.View{}, tcpip.ControlMessages{}, err } @@ -411,7 +415,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - if err := e.takeLastError(); err != nil { + if err := e.LastError(); err != nil { return 0, nil, err } @@ -683,9 +687,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // SetSockOpt implements tcpip.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case tcpip.MulticastInterfaceOption: + case *tcpip.MulticastInterfaceOption: e.mu.Lock() defer e.mu.Unlock() @@ -721,7 +725,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.multicastNICID = nic e.multicastAddr = addr - case tcpip.AddMembershipOption: + case *tcpip.AddMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { return tcpip.ErrInvalidOptionValue } @@ -752,19 +756,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - for _, mem := range e.multicastMemberships { - if mem == memToInsert { - return tcpip.ErrPortInUse - } + if _, ok := e.multicastMemberships[memToInsert]; ok { + return tcpip.ErrPortInUse } if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil { return err } - e.multicastMemberships = append(e.multicastMemberships, memToInsert) + e.multicastMemberships[memToInsert] = struct{}{} - case tcpip.RemoveMembershipOption: + case *tcpip.RemoveMembershipOption: if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { return tcpip.ErrInvalidOptionValue } @@ -786,18 +788,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} - memToRemoveIndex := -1 e.mu.Lock() defer e.mu.Unlock() - for i, mem := range e.multicastMemberships { - if mem == memToRemove { - memToRemoveIndex = i - break - } - } - if memToRemoveIndex == -1 { + if _, ok := e.multicastMemberships[memToRemove]; !ok { return tcpip.ErrBadLocalAddress } @@ -805,11 +800,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { return err } - e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] - e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + delete(e.multicastMemberships, memToRemove) - case tcpip.BindToDeviceOption: - id := tcpip.NICID(v) + case *tcpip.BindToDeviceOption: + id := tcpip.NICID(*v) if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } @@ -817,8 +811,13 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = id e.mu.Unlock() - case tcpip.SocketDetachFilterOption: + case *tcpip.SocketDetachFilterOption: return nil + + case *tcpip.LingerOption: + e.mu.Lock() + e.linger = *v + e.mu.Unlock() } return nil } @@ -960,10 +959,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { - case tcpip.ErrorOption: - return e.takeLastError() case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ @@ -977,6 +974,11 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { *o = tcpip.BindToDeviceOption(e.bindToDevice) e.mu.RUnlock() + case *tcpip.LingerOption: + e.mu.RLock() + *o = e.linger + e.mu.RUnlock() + default: return tcpip.ErrUnknownProtocolOption } @@ -1220,13 +1222,13 @@ func (*endpoint) Listen(int) *tcpip.Error { } // Accept is not supported by UDP, it just fails. -func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { +func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, tcpip.ErrNotSupported } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { return id, e.bindToDevice, err } @@ -1366,6 +1368,22 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { return result } +// verifyChecksum verifies the checksum unless RX checksum offload is enabled. +// On IPv4, UDP checksum is optional, and a zero value means the transmitter +// omitted the checksum generation (RFC768). +// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). +func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool { + if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && + (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + return hdr.CalculateChecksum(xsum) == 0xffff + } + return true +} + // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -1387,22 +1405,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - // Verify checksum unless RX checksum offload is enabled. - // On IPv4, UDP checksum is optional, and a zero value means - // the transmitter omitted the checksum generation (RFC768). - // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && - (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - if hdr.CalculateChecksum(xsum) != 0xffff { - // Checksum Error. - e.stack.Stats().UDP.ChecksumErrors.Increment() - e.stats.ReceiveErrors.ChecksumErrors.Increment() - return - } + if !verifyChecksum(r, hdr, pkt) { + // Checksum Error. + e.stack.Stats().UDP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() + return } e.stack.Stats().UDP.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 851e6b635..858c99a45 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -92,7 +92,7 @@ func (e *endpoint) Resume(s *stack.Stack) { e.stack = s - for _, m := range e.multicastMemberships { + for m := range e.multicastMemberships { if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 63d4bed7c..7d6b91a75 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/waiter" @@ -88,7 +89,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() return true } - // TODO(b/129426613): only send an ICMP message if UDP checksum is valid. + + if !verifyChecksum(r, hdr, pkt) { + // Checksum Error. + r.Stack().Stats().UDP.ChecksumErrors.Increment() + return true + } // Only send ICMP error if the address is not a multicast/broadcast // v4/v6 address or the source is not the unspecified address. @@ -197,12 +203,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans } // SetOption implements stack.TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { +func (*protocol) SetOption(tcpip.SettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { +func (*protocol) Option(tcpip.GettableTransportProtocolOption) *tcpip.Error { return tcpip.ErrUnknownProtocolOption } @@ -214,8 +220,7 @@ func (*protocol) Wait() {} // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { - _, ok := pkt.TransportHeader().Consume(header.UDPMinimumSize) - return ok + return parse.UDP(pkt) } // NewProtocol returns a UDP transport protocol. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index f87d99d5a..d5881d183 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -403,18 +403,35 @@ func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.Netw } // injectPacket creates a packet of the given flow and with the given payload, -// and injects it into the link endpoint. -func (c *testContext) injectPacket(flow testFlow, payload []byte) { +// and injects it into the link endpoint. If badChecksum is true, the packet has +// a bad checksum in the UDP header. +func (c *testContext) injectPacket(flow testFlow, payload []byte, badChecksum bool) { c.t.Helper() h := flow.header4Tuple(incoming) if flow.isV4() { buf := c.buildV4Packet(payload, &h) + if badChecksum { + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + } c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), })) } else { buf := c.buildV6Packet(payload, &h) + if badChecksum { + // Invalidate the UDP header checksum field (Unlike IPv4, zero is + // a valid checksum value for IPv6 so no need to avoid it). + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + } c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buf.ToVectorisedView(), })) @@ -522,7 +539,7 @@ func TestBindToDeviceOption(t *testing.T) { opts := stack.NICOptions{Name: "my_device"} if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil { - t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) + t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err) } // nicIDPtr is used instead of taking the address of NICID literals, which is @@ -546,16 +563,15 @@ func TestBindToDeviceOption(t *testing.T) { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr) + if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } bindToDevice := tcpip.BindToDeviceOption(88888) if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %v, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) + t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) + } else if bindToDevice != testAction.getBindToDevice { + t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) } }) } @@ -569,7 +585,7 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe c.t.Helper() payload := newPayload() - c.injectPacket(flow, payload) + c.injectPacket(flow, payload, false) // Try to receive the data. we, ch := waiter.NewChannelEntry(nil) @@ -611,12 +627,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Check the peer address. h := flow.header4Tuple(incoming) if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr) + c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr) } // Check the payload. if !bytes.Equal(payload, v) { - c.t.Fatalf("bad payload: got %x, want %x", v, payload) + c.t.Fatalf("got payload = %x, want = %x", v, payload) } // Run any checkers against the ControlMessages. @@ -677,7 +693,7 @@ func TestBindReservedPort(t *testing.T) { } defer ep.Close() if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) } } @@ -690,7 +706,7 @@ func TestBindReservedPort(t *testing.T) { // We can't bind ipv4-any on the port reserved by the connected endpoint // above, since the endpoint is dual-stack. if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want { - t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want) } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { @@ -813,7 +829,7 @@ func TestV4ReadSelfSource(t *testing.T) { } if _, _, err := c.ep.Read(nil); err != tt.wantErr { - t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr) + t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr) } }) } @@ -854,8 +870,8 @@ func TestReadOnBoundToMulticast(t *testing.T) { // Join multicast group. ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr} - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatal("SetSockOpt failed:", err) + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) } // Check that we receive multicast packets but not unicast or broadcast @@ -925,7 +941,7 @@ func TestReadFromMulticastStats(t *testing.T) { } payload := newPayload() - c.injectPacket(flow, payload) + c.injectPacket(flow, payload, false) var want uint64 = 0 if flow.isReverseMulticast() { @@ -1386,8 +1402,8 @@ func TestReadIPPacketInfo(t *testing.T) { if test.flow.isMulticast() { ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err) + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err) } } @@ -1469,7 +1485,7 @@ func TestTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{ + ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, })) @@ -1502,7 +1518,7 @@ func TestSetTTL(t *testing.T) { } else { p = ipv6.NewProtocol() } - ep := p.NewEndpoint(0, nil, nil, nil, stack.New(stack.Options{ + ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, })) @@ -1530,7 +1546,7 @@ func TestSetTOS(t *testing.T) { } // Test for expected default value. if v != 0 { - c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) + c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0) } if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil { @@ -1691,19 +1707,17 @@ func TestMulticastInterfaceOption(t *testing.T) { } } - if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %s", err) + if err := c.ep.SetSockOpt(&ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err) } // Verify multicast interface addr and NIC were set correctly. // Note that NIC must be 1 since this is our outgoing interface. - ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} var ifoptGot tcpip.MulticastInterfaceOption if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %s", err) - } - if ifoptGot != ifoptWant { - c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) + c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err) + } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant { + c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant) } }) } @@ -1727,21 +1741,33 @@ func TestV4UnknownDestination(t *testing.T) { // so that the final generated IPv4 packet is larger than // header.IPv4MinimumProcessableDatagramSize. largePayload bool + // badChecksum if true, will set an invalid checksum in the + // header. + badChecksum bool }{ - {unicastV4, true, false}, - {unicastV4, true, true}, - {multicastV4, false, false}, - {multicastV4, false, true}, - {broadcast, false, false}, - {broadcast, false, true}, - } + {unicastV4, true, false, false}, + {unicastV4, true, true, false}, + {unicastV4, false, false, true}, + {unicastV4, false, true, true}, + {multicastV4, false, false, false}, + {multicastV4, false, true, false}, + {broadcast, false, false, false}, + {broadcast, false, true, false}, + } + checksumErrors := uint64(0) for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { payload := newPayload() if tc.largePayload { payload = newMinPayload(576) } - c.injectPacket(tc.flow, payload) + c.injectPacket(tc.flow, payload, tc.badChecksum) + if tc.badChecksum { + checksumErrors++ + if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { + t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + } if !tc.icmpRequired { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -1806,19 +1832,31 @@ func TestV6UnknownDestination(t *testing.T) { // largePayload if true will result in a payload large enough to // create an IPv6 packet > header.IPv6MinimumMTU bytes. largePayload bool + // badChecksum if true, will set an invalid checksum in the + // header. + badChecksum bool }{ - {unicastV6, true, false}, - {unicastV6, true, true}, - {multicastV6, false, false}, - {multicastV6, false, true}, - } + {unicastV6, true, false, false}, + {unicastV6, true, true, false}, + {unicastV6, false, false, true}, + {unicastV6, false, true, true}, + {multicastV6, false, false, false}, + {multicastV6, false, true, false}, + } + checksumErrors := uint64(0) for _, tc := range testCases { - t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) { + t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t badChecksum:%t", tc.flow, tc.icmpRequired, tc.largePayload, tc.badChecksum), func(t *testing.T) { payload := newPayload() if tc.largePayload { payload = newMinPayload(1280) } - c.injectPacket(tc.flow, payload) + c.injectPacket(tc.flow, payload, tc.badChecksum) + if tc.badChecksum { + checksumErrors++ + if got, want := c.s.Stats().UDP.ChecksumErrors.Value(), checksumErrors; got != want { + t.Fatalf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + } if !tc.icmpRequired { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -1953,74 +1991,29 @@ func TestShortHeader(t *testing.T) { } } -// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected, +// TestBadChecksumErrors verifies if a checksum error is detected, // global and endpoint stats are incremented. -func TestIncrementChecksumErrorsV4(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv4.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV4.header4Tuple(incoming) - buf := c.buildV4Packet(payload, &h) +func TestBadChecksumErrors(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() - // Invalidate the UDP header checksum field, taking care to avoid - // overflow to zero, which would disable checksum validation. - for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { - u.SetChecksum(u.Checksum() + 1) - if u.Checksum() != 0 { - break + c.createEndpoint(flow.sockProto()) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) } - } - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) - } -} - -// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected, -// global and endpoint stats are incremented. -func TestIncrementChecksumErrorsV6(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - // Bind to wildcard. - if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %s", err) - } - - payload := newPayload() - h := unicastV6.header4Tuple(incoming) - buf := c.buildV6Packet(payload, &h) - // Invalidate the UDP header checksum field. - u := header.UDP(buf[header.IPv6MinimumSize:]) - u.SetChecksum(u.Checksum() + 1) + payload := newPayload() + c.injectPacket(flow, payload, true /* badChecksum */) - c.linkEP.InjectInbound(ipv6.ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: buf.ToVectorisedView(), - })) - - const want = 1 - if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { - t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) - } - if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } } } @@ -2350,8 +2343,10 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { NIC: nicID1, }, }, - remoteAddr: remNetSubnetBcast, - requiresBroadcastOpt: true, + remoteAddr: remNetSubnetBcast, + // TODO(gvisor.dev/issue/3938): Once we support marking a route as + // broadcast, this test should require the broadcast option to be set. + requiresBroadcastOpt: false, }, } diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 052b6b99d..64d17f661 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -22,6 +22,7 @@ import ( "net" "os" "path" + "path/filepath" "regexp" "strconv" "strings" @@ -403,10 +404,13 @@ func (c *Container) CopyFiles(opts *RunOpts, target string, sources ...string) { return } for _, name := range sources { - src, err := testutil.FindFile(name) - if err != nil { - c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %v", name, err) - return + src := name + if !filepath.IsAbs(src) { + src, err = testutil.FindFile(name) + if err != nil { + c.copyErr = fmt.Errorf("testutil.FindFile(%q) failed: %w", name, err) + return + } } dst := path.Join(dir, path.Base(name)) if err := testutil.Copy(src, dst); err != nil { diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 952871f95..7027df1a5 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -60,7 +60,6 @@ var ( // enabled for each run. pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug") pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug") - pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug") pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug") pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug") ) diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go index f0396ef24..55f9496cd 100644 --- a/pkg/test/dockerutil/profile.go +++ b/pkg/test/dockerutil/profile.go @@ -63,7 +63,7 @@ type Pprof struct { // MakePprofFromFlags makes a Pprof profile from flags. func MakePprofFromFlags(c *Container) *Pprof { - if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) { + if !(*pprofBlock || *pprofCPU || *pprofHeap || *pprofMutex) { return nil } return &Pprof{ diff --git a/pkg/test/testutil/BUILD b/pkg/test/testutil/BUILD index 2d8f56bc0..c4b131896 100644 --- a/pkg/test/testutil/BUILD +++ b/pkg/test/testutil/BUILD @@ -12,7 +12,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/sync", - "//runsc/boot", + "//runsc/config", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 1580527b5..06fb823f6 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -44,7 +44,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -133,25 +133,28 @@ func Command(logger Logger, args ...string) *Cmd { // TestConfig returns the default configuration to use in tests. Note that // 'RootDir' must be set by caller if required. -func TestConfig(t *testing.T) *boot.Config { +func TestConfig(t *testing.T) *config.Config { logDir := os.TempDir() if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { logDir = dir + "/" } - return &boot.Config{ - Debug: true, - DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"), - LogFormat: "text", - DebugLogFormat: "text", - LogPackets: true, - Network: boot.NetworkNone, - Strace: true, - Platform: "ptrace", - FileAccess: boot.FileAccessExclusive, - NumNetworkChannels: 1, - TestOnlyAllowRunAsCurrentUserWithoutChroot: true, + // Only register flags if config is being used. Otherwise anyone that uses + // testutil will get flags registered and they may conflict. + config.RegisterFlags() + + conf, err := config.NewFromFlags() + if err != nil { + panic(err) } + // Change test defaults. + conf.Debug = true + conf.DebugLog = path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%") + conf.LogPackets = true + conf.Network = config.NetworkNone + conf.Strace = true + conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true + return conf } // NewSpecWithArgs creates a simple spec with the given args suitable for use @@ -203,7 +206,7 @@ func SetupRootDir() (string, func(), error) { // SetupContainer creates a bundle and root dir for the container, generates a // test config, and writes the spec to config.json in the bundle dir. -func SetupContainer(spec *specs.Spec, conf *boot.Config) (rootDir, bundleDir string, cleanup func(), err error) { +func SetupContainer(spec *specs.Spec, conf *config.Config) (rootDir, bundleDir string, cleanup func(), err error) { rootDir, rootCleanup, err := SetupRootDir() if err != nil { return "", "", nil, err @@ -243,12 +246,15 @@ func writeSpec(dir string, spec *specs.Spec) error { return ioutil.WriteFile(filepath.Join(dir, "config.json"), b, 0755) } +// idRandomSrc is a pseudo random generator used to in RandomID. +var idRandomSrc = rand.New(rand.NewSource(time.Now().UnixNano())) + // RandomID returns 20 random bytes following the given prefix. func RandomID(prefix string) string { // Read 20 random bytes. b := make([]byte, 20) // "[Read] always returns len(p) and a nil error." --godoc - if _, err := rand.Read(b); err != nil { + if _, err := idRandomSrc.Read(b); err != nil { panic("rand.Read failed: " + err.Error()) } if prefix != "" { @@ -326,13 +332,13 @@ func PollContext(ctx context.Context, cb func() error) error { } // WaitForHTTP tries GET requests on a port until the call succeeds or timeout. -func WaitForHTTP(port int, timeout time.Duration) error { +func WaitForHTTP(ip string, port int, timeout time.Duration) error { cb := func() error { c := &http.Client{ // Calculate timeout to be able to do minimum 5 attempts. Timeout: timeout / 5, } - url := fmt.Sprintf("http://localhost:%d/", port) + url := fmt.Sprintf("http://%s:%d/", ip, port) resp, err := c.Get(url) if err != nil { log.Printf("Waiting %s: %v", url, err) diff --git a/pkg/unet/unet.go b/pkg/unet/unet.go index d843f19cf..c976d7230 100644 --- a/pkg/unet/unet.go +++ b/pkg/unet/unet.go @@ -522,7 +522,7 @@ func (s *ServerSocket) Listen() error { // This is always blocking. // // Preconditions: -// * ServerSocket is listening (Listen called). +// * ServerSocket is listening (Listen called). func (s *ServerSocket) Accept() (*Socket, error) { fd, ok := s.socket.enterFD() if !ok { diff --git a/pkg/usermem/addr_range_seq_unsafe.go b/pkg/usermem/addr_range_seq_unsafe.go index c09337c15..495896ded 100644 --- a/pkg/usermem/addr_range_seq_unsafe.go +++ b/pkg/usermem/addr_range_seq_unsafe.go @@ -81,8 +81,10 @@ func AddrRangeSeqFromSlice(slice []AddrRange) AddrRangeSeq { return addrRangeSeqFromSliceLimited(slice, limit) } -// Preconditions: The combined length of all AddrRanges in slice <= limit. -// limit >= 0. If len(slice) != 0, then limit > 0. +// Preconditions: +// * The combined length of all AddrRanges in slice <= limit. +// * limit >= 0. +// * If len(slice) != 0, then limit > 0. func addrRangeSeqFromSliceLimited(slice []AddrRange, limit int64) AddrRangeSeq { switch len(slice) { case 0: diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index cd6a0ea6b..27279b409 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -54,8 +54,10 @@ type IO interface { // of bytes zeroed. If the number of bytes zeroed is < toZero, it returns a // non-nil error explaining why. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. toZero >= 0. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * toZero >= 0. ZeroOut(ctx context.Context, addr Addr, toZero int64, opts IOOpts) (int64, error) // CopyOutFrom copies ars.NumBytes() bytes from src to the memory mapped at @@ -66,9 +68,11 @@ type IO interface { // // CopyOutFrom calls src.ReadToBlocks at most once. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. src.ReadToBlocks must not block - // on mm.MemoryManager.activeMu or any preceding locks in the lock order. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * src.ReadToBlocks must not block on mm.MemoryManager.activeMu or + // any preceding locks in the lock order. CopyOutFrom(ctx context.Context, ars AddrRangeSeq, src safemem.Reader, opts IOOpts) (int64, error) // CopyInTo copies ars.NumBytes() bytes from the memory mapped at ars to @@ -78,10 +82,11 @@ type IO interface { // // CopyInTo calls dst.WriteFromBlocks at most once. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. dst.WriteFromBlocks must not - // block on mm.MemoryManager.activeMu or any preceding locks in the lock - // order. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * dst.WriteFromBlocks must not block on mm.MemoryManager.activeMu or + // any preceding locks in the lock order. CopyInTo(ctx context.Context, ars AddrRangeSeq, dst safemem.Writer, opts IOOpts) (int64, error) // TODO(jamieliu): The requirement that CopyOutFrom/CopyInTo call src/dst @@ -93,25 +98,28 @@ type IO interface { // SwapUint32 atomically sets the uint32 value at addr to new and // returns the previous value. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * addr must be aligned to a 4-byte boundary. SwapUint32(ctx context.Context, addr Addr, new uint32, opts IOOpts) (uint32, error) // CompareAndSwapUint32 atomically compares the uint32 value at addr to // old; if they are equal, the value in memory is replaced by new. In // either case, the previous value stored in memory is returned. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * addr must be aligned to a 4-byte boundary. CompareAndSwapUint32(ctx context.Context, addr Addr, old, new uint32, opts IOOpts) (uint32, error) // LoadUint32 atomically loads the uint32 value at addr and returns it. // - // Preconditions: The caller must not hold mm.MemoryManager.mappingMu or - // any following locks in the lock order. addr must be aligned to a 4-byte - // boundary. + // Preconditions: + // * The caller must not hold mm.MemoryManager.mappingMu or any + // following locks in the lock order. + // * addr must be aligned to a 4-byte boundary. LoadUint32(ctx context.Context, addr Addr, opts IOOpts) (uint32, error) } @@ -183,7 +191,7 @@ func (rw *IOReadWriter) Write(src []byte) (int, error) { // CopyObjectOut must use reflection to encode src; performance-sensitive // clients should do encoding manually and use uio.CopyOut directly. // -// Preconditions: As for IO.CopyOut. +// Preconditions: Same as IO.CopyOut. func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts IOOpts) (int, error) { w := &IOReadWriter{ Ctx: ctx, @@ -205,7 +213,7 @@ func CopyObjectOut(ctx context.Context, uio IO, addr Addr, src interface{}, opts // CopyObjectIn must use reflection to decode dst; performance-sensitive // clients should use uio.CopyIn directly and do decoding manually. // -// Preconditions: As for IO.CopyIn. +// Preconditions: Same as IO.CopyIn. func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts IOOpts) (int, error) { r := &IOReadWriter{ Ctx: ctx, @@ -233,7 +241,8 @@ const ( // would exceed maxlen, CopyStringIn returns the string truncated to maxlen and // ENAMETOOLONG. // -// Preconditions: As for IO.CopyFromUser. maxlen >= 0. +// Preconditions: Same as IO.CopyFromUser, plus: +// * maxlen >= 0. func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) { initLen := maxlen if initLen > copyStringMaxInitBufLen { @@ -287,7 +296,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt // less. CopyOutVec returns the number of bytes copied; if this is less than // the maximum, it returns a non-nil error explaining why. // -// Preconditions: As for IO.CopyOut. +// Preconditions: Same as IO.CopyOut. func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts IOOpts) (int, error) { var done int for !ars.IsEmpty() && done < len(src) { @@ -311,7 +320,7 @@ func CopyOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, src []byte, opts // less. CopyInVec returns the number of bytes copied; if this is less than the // maximum, it returns a non-nil error explaining why. // -// Preconditions: As for IO.CopyIn. +// Preconditions: Same as IO.CopyIn. func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts IOOpts) (int, error) { var done int for !ars.IsEmpty() && done < len(dst) { @@ -335,7 +344,7 @@ func CopyInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dst []byte, opts I // ZeroOutVec returns the number of bytes written; if this is less than the // maximum, it returns a non-nil error explaining why. // -// Preconditions: As for IO.ZeroOut. +// Preconditions: Same as IO.ZeroOut. func ZeroOutVec(ctx context.Context, uio IO, ars AddrRangeSeq, toZero int64, opts IOOpts) (int64, error) { var done int64 for !ars.IsEmpty() && done < toZero { @@ -388,7 +397,7 @@ func isASCIIWhitespace(b byte) bool { // // - CopyInt32StringsInVec returns EINVAL if ars.NumBytes() == 0. // -// Preconditions: As for CopyInVec. +// Preconditions: Same as CopyInVec. func CopyInt32StringsInVec(ctx context.Context, uio IO, ars AddrRangeSeq, dsts []int32, opts IOOpts) (int64, error) { if len(dsts) == 0 { return 0, nil @@ -481,28 +490,28 @@ func (s IOSequence) NumBytes() int64 { // DropFirst returns a copy of s with s.Addrs.DropFirst(n). // -// Preconditions: As for AddrRangeSeq.DropFirst. +// Preconditions: Same as AddrRangeSeq.DropFirst. func (s IOSequence) DropFirst(n int) IOSequence { return IOSequence{s.IO, s.Addrs.DropFirst(n), s.Opts} } // DropFirst64 returns a copy of s with s.Addrs.DropFirst64(n). // -// Preconditions: As for AddrRangeSeq.DropFirst64. +// Preconditions: Same as AddrRangeSeq.DropFirst64. func (s IOSequence) DropFirst64(n int64) IOSequence { return IOSequence{s.IO, s.Addrs.DropFirst64(n), s.Opts} } // TakeFirst returns a copy of s with s.Addrs.TakeFirst(n). // -// Preconditions: As for AddrRangeSeq.TakeFirst. +// Preconditions: Same as AddrRangeSeq.TakeFirst. func (s IOSequence) TakeFirst(n int) IOSequence { return IOSequence{s.IO, s.Addrs.TakeFirst(n), s.Opts} } // TakeFirst64 returns a copy of s with s.Addrs.TakeFirst64(n). // -// Preconditions: As for AddrRangeSeq.TakeFirst64. +// Preconditions: Same as AddrRangeSeq.TakeFirst64. func (s IOSequence) TakeFirst64(n int64) IOSequence { return IOSequence{s.IO, s.Addrs.TakeFirst64(n), s.Opts} } @@ -512,7 +521,7 @@ func (s IOSequence) TakeFirst64(n int64) IOSequence { // As with CopyOutVec, if s.NumBytes() < len(src), the copy will be truncated // to s.NumBytes(), and a nil error will be returned. // -// Preconditions: As for CopyOutVec. +// Preconditions: Same as CopyOutVec. func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) { return CopyOutVec(ctx, s.IO, s.Addrs, src, s.Opts) } @@ -522,7 +531,7 @@ func (s IOSequence) CopyOut(ctx context.Context, src []byte) (int, error) { // As with CopyInVec, if s.NumBytes() < len(dst), the copy will be truncated to // s.NumBytes(), and a nil error will be returned. // -// Preconditions: As for CopyInVec. +// Preconditions: Same as CopyInVec. func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) { return CopyInVec(ctx, s.IO, s.Addrs, dst, s.Opts) } @@ -532,21 +541,21 @@ func (s IOSequence) CopyIn(ctx context.Context, dst []byte) (int, error) { // As with ZeroOutVec, if s.NumBytes() < toZero, the write will be truncated // to s.NumBytes(), and a nil error will be returned. // -// Preconditions: As for ZeroOutVec. +// Preconditions: Same as ZeroOutVec. func (s IOSequence) ZeroOut(ctx context.Context, toZero int64) (int64, error) { return ZeroOutVec(ctx, s.IO, s.Addrs, toZero, s.Opts) } // CopyOutFrom invokes s.CopyOutFrom over s.Addrs. // -// Preconditions: As for IO.CopyOutFrom. +// Preconditions: Same as IO.CopyOutFrom. func (s IOSequence) CopyOutFrom(ctx context.Context, src safemem.Reader) (int64, error) { return s.IO.CopyOutFrom(ctx, s.Addrs, src, s.Opts) } // CopyInTo invokes s.CopyInTo over s.Addrs. // -// Preconditions: As for IO.CopyInTo. +// Preconditions: Same as IO.CopyInTo. func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, error) { return s.IO.CopyInTo(ctx, s.Addrs, dst, s.Opts) } diff --git a/runsc/BUILD b/runsc/BUILD index 96f697a5f..33d8554af 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary", "pkg_deb", "pkg_tar") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -17,8 +17,8 @@ go_binary( "//pkg/log", "//pkg/refs", "//pkg/sentry/platform", - "//runsc/boot", "//runsc/cmd", + "//runsc/config", "//runsc/flag", "//runsc/specutils", "@com_github_google_subcommands//:go_default_library", @@ -53,70 +53,14 @@ go_binary( "//pkg/log", "//pkg/refs", "//pkg/sentry/platform", - "//runsc/boot", "//runsc/cmd", + "//runsc/config", "//runsc/flag", "//runsc/specutils", "@com_github_google_subcommands//:go_default_library", ], ) -pkg_tar( - name = "debian-bin", - srcs = [ - ":runsc", - "//shim/v1:gvisor-containerd-shim", - "//shim/v2:containerd-shim-runsc-v1", - ], - mode = "0755", - package_dir = "/usr/bin", -) - -pkg_tar( - name = "debian-data", - extension = "tar.gz", - deps = [ - ":debian-bin", - "//shim:config", - ], -) - -genrule( - name = "deb-version", - # Note that runsc must appear in the srcs parameter and not the tools - # parameter, otherwise it will not be stamped. This is reasonable, as tools - # may be encoded differently in the build graph (cached more aggressively - # because they are assumes to be hermetic). - srcs = [":runsc"], - outs = ["version.txt"], - # Note that the little dance here is necessary because files in the $(SRCS) - # attribute are not executable by default, and we can't touch in place. - cmd = "cp $(location :runsc) $(@D)/runsc && \ - chmod a+x $(@D)/runsc && \ - $(@D)/runsc -version | grep version | sed 's/^[^0-9]*//' > $@ && \ - rm -f $(@D)/runsc", - stamp = 1, -) - -pkg_deb( - name = "runsc-debian", - architecture = "amd64", - data = ":debian-data", - # Note that the description_file will be flatten (all newlines removed), - # and therefore it is kept to a simple one-line description. The expected - # format for debian packages is "short summary\nLonger explanation of - # tool." and this is impossible with the flattening. - description_file = "debian/description", - homepage = "https://gvisor.dev/", - maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>", - package = "runsc", - postinst = "debian/postinst.sh", - version_file = ":version.txt", - visibility = [ - "//visibility:public", - ], -) - sh_test( name = "version_test", size = "small", diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 9f52438c2..2d9517f4a 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -8,7 +8,6 @@ go_library( "compat.go", "compat_amd64.go", "compat_arm64.go", - "config.go", "controller.go", "debug.go", "events.go", @@ -27,10 +26,13 @@ go_library( deps = [ "//pkg/abi", "//pkg/abi/linux", + "//pkg/bpf", + "//pkg/cleanup", "//pkg/context", "//pkg/control/server", "//pkg/cpuid", "//pkg/eventchannel", + "//pkg/fd", "//pkg/fspath", "//pkg/log", "//pkg/memutil", @@ -105,7 +107,9 @@ go_library( "//runsc/boot/filter", "//runsc/boot/platforms", "//runsc/boot/pprof", + "//runsc/config", "//runsc/specutils", + "//runsc/specutils/seccomp", "@com_github_golang_protobuf//proto:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", @@ -123,6 +127,7 @@ go_test( library = ":boot", deps = [ "//pkg/control/server", + "//pkg/fd", "//pkg/fspath", "//pkg/log", "//pkg/p9", @@ -131,6 +136,7 @@ go_test( "//pkg/sentry/vfs", "//pkg/sync", "//pkg/unet", + "//runsc/config", "//runsc/fsgofer", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", "@org_golang_x_sys//unix:go_default_library", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 626a3816e..894651519 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -22,6 +22,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/control/server" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -33,6 +34,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot/pprof" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -220,7 +222,7 @@ type StartArgs struct { Spec *specs.Spec // Config is the runsc-specific configuration for the sandbox. - Conf *Config + Conf *config.Config // CID is the ID of the container to start. CID string @@ -256,13 +258,20 @@ func (cm *containerManager) Start(args *StartArgs, _ *struct{}) error { // All validation passed, logs the spec for debugging. specutils.LogSpec(args.Spec) - err := cm.l.startContainer(args.Spec, args.Conf, args.CID, args.FilePayload.Files) + fds, err := fd.NewFromFiles(args.FilePayload.Files) if err != nil { + return err + } + defer func() { + for _, fd := range fds { + _ = fd.Close() + } + }() + if err := cm.l.startContainer(args.Spec, args.Conf, args.CID, fds); err != nil { log.Debugf("containerManager.Start failed %q: %+v: %v", args.CID, args, err) return err } log.Debugf("Container %q started", args.CID) - return nil } diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 149eb0b1b..4ed28b5cd 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -29,7 +29,7 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_CLOCK_GETTIME: {}, syscall.SYS_CLONE: []seccomp.Rule{ { - seccomp.AllowValue( + seccomp.EqualTo( syscall.CLONE_VM | syscall.CLONE_FS | syscall.CLONE_FILES | @@ -42,26 +42,26 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_DUP: {}, syscall.SYS_DUP3: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.O_CLOEXEC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.O_CLOEXEC), }, }, syscall.SYS_EPOLL_CREATE1: {}, syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, }, syscall.SYS_EVENTFD2: []seccomp.Rule{ { - seccomp.AllowValue(0), - seccomp.AllowValue(0), + seccomp.EqualTo(0), + seccomp.EqualTo(0), }, }, syscall.SYS_EXIT: {}, @@ -70,16 +70,16 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_FCHMOD: {}, syscall.SYS_FCNTL: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_SETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_SETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFD), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFD), }, }, syscall.SYS_FSTAT: {}, @@ -87,52 +87,52 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_FTRUNCATE: {}, syscall.SYS_FUTEX: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, }, // Non-private variants are included for flipcall support. They are otherwise // unncessary, as the sentry will use only private futexes internally. { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, { - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE), - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE), + seccomp.MatchAny{}, }, }, syscall.SYS_GETPID: {}, unix.SYS_GETRANDOM: {}, syscall.SYS_GETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_DOMAIN), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_DOMAIN), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_TYPE), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_TYPE), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_ERROR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_ERROR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_SNDBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_SNDBUF), }, }, syscall.SYS_GETTID: {}, @@ -141,34 +141,34 @@ var allowedSyscalls = seccomp.SyscallRules{ // setting/getting termios and winsize. syscall.SYS_IOCTL: []seccomp.Rule{ { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCGETS), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCGETS), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCSETS), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCSETS), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCSETSF), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCSETSF), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TCSETSW), - seccomp.AllowAny{}, /* termios struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TCSETSW), + seccomp.MatchAny{}, /* termios struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TIOCSWINSZ), - seccomp.AllowAny{}, /* winsize struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TIOCSWINSZ), + seccomp.MatchAny{}, /* winsize struct */ }, { - seccomp.AllowAny{}, /* fd */ - seccomp.AllowValue(linux.TIOCGWINSZ), - seccomp.AllowAny{}, /* winsize struct */ + seccomp.MatchAny{}, /* fd */ + seccomp.EqualTo(linux.TIOCGWINSZ), + seccomp.MatchAny{}, /* winsize struct */ }, }, syscall.SYS_LSEEK: {}, @@ -182,46 +182,46 @@ var allowedSyscalls = seccomp.SyscallRules{ // TODO(b/148688965): Remove once this is gone from Go. syscall.SYS_MLOCK: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(4096), + seccomp.MatchAny{}, + seccomp.EqualTo(4096), }, }, syscall.SYS_MMAP: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_SHARED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_SHARED), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.PROT_WRITE | syscall.PROT_READ), - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.PROT_WRITE | syscall.PROT_READ), + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), }, }, syscall.SYS_MPROTECT: {}, @@ -237,32 +237,32 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_READ: {}, syscall.SYS_RECVMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), }, }, syscall.SYS_RECVMMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(fdbased.MaxMsgsPerRecv), - seccomp.AllowValue(syscall.MSG_DONTWAIT), - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(fdbased.MaxMsgsPerRecv), + seccomp.EqualTo(syscall.MSG_DONTWAIT), + seccomp.EqualTo(0), }, }, unix.SYS_SENDMMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT), - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT), + seccomp.EqualTo(0), }, }, syscall.SYS_RESTART_SYSCALL: {}, @@ -272,49 +272,49 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_SCHED_YIELD: {}, syscall.SYS_SENDMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), }, }, syscall.SYS_SETITIMER: {}, syscall.SYS_SHUTDOWN: []seccomp.Rule{ // Used by fs/host to shutdown host sockets. - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RD)}, - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_WR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RD)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_WR)}, // Used by unet to shutdown connections. - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)}, }, syscall.SYS_SIGALTSTACK: {}, unix.SYS_STATX: {}, syscall.SYS_SYNC_FILE_RANGE: {}, syscall.SYS_TEE: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(1), /* len */ - seccomp.AllowValue(unix.SPLICE_F_NONBLOCK), /* flags */ + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(1), /* len */ + seccomp.EqualTo(unix.SPLICE_F_NONBLOCK), /* flags */ }, }, syscall.SYS_TGKILL: []seccomp.Rule{ { - seccomp.AllowValue(uint64(os.Getpid())), + seccomp.EqualTo(uint64(os.Getpid())), }, }, syscall.SYS_UTIMENSAT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(0), /* null pathname */ - seccomp.AllowAny{}, - seccomp.AllowValue(0), /* flags */ + seccomp.MatchAny{}, + seccomp.EqualTo(0), /* null pathname */ + seccomp.MatchAny{}, + seccomp.EqualTo(0), /* flags */ }, }, syscall.SYS_WRITE: {}, // For rawfile.NonBlockingWriteIovec. syscall.SYS_WRITEV: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, seccomp.GreaterThan(0), }, }, @@ -325,10 +325,10 @@ func hostInetFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ syscall.SYS_ACCEPT4: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), }, }, syscall.SYS_BIND: {}, @@ -337,84 +337,84 @@ func hostInetFilters() seccomp.SyscallRules { syscall.SYS_GETSOCKNAME: {}, syscall.SYS_GETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_TOS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_TOS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_RECVTOS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVTOS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_TCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_TCLASS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_RECVTCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVTCLASS), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_V6ONLY), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_V6ONLY), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_ERROR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_ERROR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_KEEPALIVE), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_KEEPALIVE), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_SNDBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_SNDBUF), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_RCVBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_RCVBUF), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_REUSEADDR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_REUSEADDR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_TYPE), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_TYPE), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_LINGER), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_LINGER), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_TCP), - seccomp.AllowValue(syscall.TCP_NODELAY), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(syscall.TCP_NODELAY), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_TCP), - seccomp.AllowValue(syscall.TCP_INFO), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(syscall.TCP_INFO), }, }, syscall.SYS_IOCTL: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.TIOCOUTQ), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.TIOCOUTQ), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.TIOCINQ), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.TIOCINQ), }, }, syscall.SYS_LISTEN: {}, @@ -425,103 +425,103 @@ func hostInetFilters() seccomp.SyscallRules { syscall.SYS_SENDTO: {}, syscall.SYS_SETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_V6ONLY), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_V6ONLY), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_SNDBUF), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_SNDBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_RCVBUF), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_RCVBUF), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_REUSEADDR), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_REUSEADDR), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_TCP), - seccomp.AllowValue(syscall.TCP_NODELAY), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(syscall.TCP_NODELAY), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_TOS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_TOS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IP), - seccomp.AllowValue(syscall.IP_RECVTOS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IP), + seccomp.EqualTo(syscall.IP_RECVTOS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_TCLASS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_TCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_IPV6), - seccomp.AllowValue(syscall.IPV6_RECVTCLASS), - seccomp.AllowAny{}, - seccomp.AllowValue(4), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(syscall.IPV6_RECVTCLASS), + seccomp.MatchAny{}, + seccomp.EqualTo(4), }, }, syscall.SYS_SHUTDOWN: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SHUT_RD), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SHUT_RD), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SHUT_WR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SHUT_WR), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SHUT_RDWR), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SHUT_RDWR), }, }, syscall.SYS_SOCKET: []seccomp.Rule{ { - seccomp.AllowValue(syscall.AF_INET), - seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET), + seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_INET), - seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET), + seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_INET6), - seccomp.AllowValue(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET6), + seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_INET6), - seccomp.AllowValue(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_INET6), + seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, }, syscall.SYS_WRITEV: {}, @@ -532,20 +532,20 @@ func controlServerFilters(fd int) seccomp.SyscallRules { return seccomp.SyscallRules{ syscall.SYS_ACCEPT: []seccomp.Rule{ { - seccomp.AllowValue(fd), + seccomp.EqualTo(fd), }, }, syscall.SYS_LISTEN: []seccomp.Rule{ { - seccomp.AllowValue(fd), - seccomp.AllowValue(16 /* unet.backlog */), + seccomp.EqualTo(fd), + seccomp.EqualTo(16 /* unet.backlog */), }, }, syscall.SYS_GETSOCKOPT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.SOL_SOCKET), - seccomp.AllowValue(syscall.SO_PEERCRED), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.SOL_SOCKET), + seccomp.EqualTo(syscall.SO_PEERCRED), }, }, } diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go index 5335ff82c..24e13565e 100644 --- a/runsc/boot/filter/config_amd64.go +++ b/runsc/boot/filter/config_amd64.go @@ -25,7 +25,6 @@ import ( func init() { allowedSyscalls[syscall.SYS_ARCH_PRCTL] = append(allowedSyscalls[syscall.SYS_ARCH_PRCTL], - seccomp.Rule{seccomp.AllowValue(linux.ARCH_GET_FS)}, - seccomp.Rule{seccomp.AllowValue(linux.ARCH_SET_FS)}, + seccomp.Rule{seccomp.EqualTo(linux.ARCH_SET_FS)}, ) } diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go index 194952a7b..7b8669595 100644 --- a/runsc/boot/filter/config_profile.go +++ b/runsc/boot/filter/config_profile.go @@ -25,9 +25,9 @@ func profileFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ syscall.SYS_OPENAT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), }, }, } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 9dd5b0184..ddf288456 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -34,6 +34,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/gofer" @@ -48,6 +49,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -66,7 +68,7 @@ const ( // tmpfs has some extra supported options that we must pass through. var tmpfsAllowedData = []string{"mode", "uid", "gid"} -func addOverlay(ctx context.Context, conf *Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) { +func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name string, lowerFlags fs.MountSourceFlags) (*fs.Inode, error) { // Upper layer uses the same flags as lower, but it must be read-write. upperFlags := lowerFlags upperFlags.ReadOnly = false @@ -156,7 +158,7 @@ func compileMounts(spec *specs.Spec) []specs.Mount { } // p9MountData creates a slice of p9 mount data. -func p9MountData(fd int, fa FileAccessType, vfs2 bool) []string { +func p9MountData(fd int, fa config.FileAccessType, vfs2 bool) []string { opts := []string{ "trans=fd", "rfdno=" + strconv.Itoa(fd), @@ -167,7 +169,7 @@ func p9MountData(fd int, fa FileAccessType, vfs2 bool) []string { // enablement. opts = append(opts, "privateunixsocket=true") } - if fa == FileAccessShared { + if fa == config.FileAccessShared { opts = append(opts, "cache=remote_revalidating") } return opts @@ -252,7 +254,7 @@ func mustFindFilesystem(name string) fs.Filesystem { // addSubmountOverlay overlays the inode over a ramfs tree containing the given // paths. -func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string) (*fs.Inode, error) { +func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string, mf fs.MountSourceFlags) (*fs.Inode, error) { // Construct a ramfs tree of mount points. The contents never // change, so this can be fully caching. There's no real // filesystem backing this tree, so we set the filesystem to @@ -262,7 +264,7 @@ func addSubmountOverlay(ctx context.Context, inode *fs.Inode, submounts []string if err != nil { return nil, fmt.Errorf("creating mount tree: %v", err) } - overlayInode, err := fs.NewOverlayRoot(ctx, inode, mountTree, fs.MountSourceFlags{}) + overlayInode, err := fs.NewOverlayRoot(ctx, inode, mountTree, mf) if err != nil { return nil, fmt.Errorf("adding mount overlay: %v", err) } @@ -281,7 +283,7 @@ func subtargets(root string, mnts []specs.Mount) []string { return targets } -func setupContainerFS(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { +func setupContainerFS(ctx context.Context, conf *config.Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { if conf.VFS2 { return setupContainerVFS2(ctx, conf, mntr, procArgs) } @@ -319,14 +321,14 @@ func adjustDirentCache(k *kernel.Kernel) error { } type fdDispenser struct { - fds []int + fds []*fd.FD } func (f *fdDispenser) remove() int { if f.empty() { panic("fdDispenser out of fds") } - rv := f.fds[0] + rv := f.fds[0].Release() f.fds = f.fds[1:] return rv } @@ -452,27 +454,27 @@ func (m *mountHint) isSupported() bool { func (m *mountHint) checkCompatible(mount specs.Mount) error { // Remove options that don't affect to mount's behavior. masterOpts := filterUnsupportedOptions(m.mount) - slaveOpts := filterUnsupportedOptions(mount) + replicaOpts := filterUnsupportedOptions(mount) - if len(masterOpts) != len(slaveOpts) { - return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts) + if len(masterOpts) != len(replicaOpts) { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, replicaOpts) } sort.Strings(masterOpts) - sort.Strings(slaveOpts) + sort.Strings(replicaOpts) for i, opt := range masterOpts { - if opt != slaveOpts[i] { - return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, slaveOpts) + if opt != replicaOpts[i] { + return fmt.Errorf("mount options in annotations differ from container mount, annotation: %s, mount: %s", masterOpts, replicaOpts) } } return nil } -func (m *mountHint) fileAccessType() FileAccessType { +func (m *mountHint) fileAccessType() config.FileAccessType { if m.share == container { - return FileAccessExclusive + return config.FileAccessExclusive } - return FileAccessShared + return config.FileAccessShared } func filterUnsupportedOptions(mount specs.Mount) []string { @@ -563,7 +565,7 @@ type containerMounter struct { hints *podMountHints } -func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hints *podMountHints) *containerMounter { +func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints) *containerMounter { return &containerMounter{ root: spec.Root, mounts: compileMounts(spec), @@ -576,7 +578,7 @@ func newContainerMounter(spec *specs.Spec, goferFDs []int, k *kernel.Kernel, hin // processHints processes annotations that container hints about how volumes // should be mounted (e.g. a volume shared between containers). It must be // called for the root container only. -func (c *containerMounter) processHints(conf *Config, creds *auth.Credentials) error { +func (c *containerMounter) processHints(conf *config.Config, creds *auth.Credentials) error { if conf.VFS2 { return c.processHintsVFS2(conf, creds) } @@ -600,7 +602,7 @@ func (c *containerMounter) processHints(conf *Config, creds *auth.Credentials) e // setupFS is used to set up the file system for all containers. This is the // main entry point method, with most of the other being internal only. It // returns the mount namespace that is created for the container. -func (c *containerMounter) setupFS(conf *Config, procArgs *kernel.CreateProcessArgs) (*fs.MountNamespace, error) { +func (c *containerMounter) setupFS(conf *config.Config, procArgs *kernel.CreateProcessArgs) (*fs.MountNamespace, error) { log.Infof("Configuring container's file system") // Create context with root credentials to mount the filesystem (the current @@ -626,7 +628,7 @@ func (c *containerMounter) setupFS(conf *Config, procArgs *kernel.CreateProcessA return mns, nil } -func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Config) (*fs.MountNamespace, error) { +func (c *containerMounter) createMountNamespace(ctx context.Context, conf *config.Config) (*fs.MountNamespace, error) { rootInode, err := c.createRootMount(ctx, conf) if err != nil { return nil, fmt.Errorf("creating filesystem for container: %v", err) @@ -638,7 +640,7 @@ func (c *containerMounter) createMountNamespace(ctx context.Context, conf *Confi return mns, nil } -func (c *containerMounter) mountSubmounts(ctx context.Context, conf *Config, mns *fs.MountNamespace) error { +func (c *containerMounter) mountSubmounts(ctx context.Context, conf *config.Config, mns *fs.MountNamespace) error { root := mns.Root() defer root.DecRef(ctx) @@ -674,7 +676,7 @@ func (c *containerMounter) checkDispenser() error { // mountSharedMaster mounts the master of a volume that is shared among // containers in a pod. It returns the root mount's inode. -func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *Config, hint *mountHint) (*fs.Inode, error) { +func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *config.Config, hint *mountHint) (*fs.Inode, error) { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, hint.mount) @@ -714,7 +716,7 @@ func (c *containerMounter) mountSharedMaster(ctx context.Context, conf *Config, } // createRootMount creates the root filesystem. -func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (*fs.Inode, error) { +func (c *containerMounter) createRootMount(ctx context.Context, conf *config.Config) (*fs.Inode, error) { // First construct the filesystem from the spec.Root. mf := fs.MountSourceFlags{ReadOnly: c.root.Readonly || conf.Overlay} @@ -739,7 +741,7 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (* // for submount paths. "/dev" "/sys" "/proc" and "/tmp" are always // mounted even if they are not in the spec. submounts := append(subtargets("/", c.mounts), "/dev", "/sys", "/proc", "/tmp") - rootInode, err = addSubmountOverlay(ctx, rootInode, submounts) + rootInode, err = addSubmountOverlay(ctx, rootInode, submounts, mf) if err != nil { return nil, fmt.Errorf("adding submount overlay: %v", err) } @@ -759,7 +761,7 @@ func (c *containerMounter) createRootMount(ctx context.Context, conf *Config) (* // getMountNameAndOptions retrieves the fsName, opts, and useOverlay values // used for mounts. -func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) (string, []string, bool, error) { +func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.Mount) (string, []string, bool, error) { var ( fsName string opts []string @@ -793,19 +795,19 @@ func (c *containerMounter) getMountNameAndOptions(conf *Config, m specs.Mount) ( return fsName, opts, useOverlay, nil } -func (c *containerMounter) getMountAccessType(mount specs.Mount) FileAccessType { +func (c *containerMounter) getMountAccessType(mount specs.Mount) config.FileAccessType { if hint := c.hints.findMount(mount); hint != nil { return hint.fileAccessType() } // Non-root bind mounts are always shared if no hints were provided. - return FileAccessShared + return config.FileAccessShared } // mountSubmount mounts volumes inside the container's root. Because mounts may // be readonly, a lower ramfs overlay is added to create the mount point dir. // Another overlay is added with tmpfs on top if Config.Overlay is true. // 'm.Destination' must be an absolute path with '..' and symlinks resolved. -func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent, m specs.Mount) error { +func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent, m specs.Mount) error { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) @@ -849,7 +851,7 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *Config, mns submounts := subtargets(m.Destination, c.mounts) if len(submounts) > 0 { log.Infof("Adding submount overlay over %q", m.Destination) - inode, err = addSubmountOverlay(ctx, inode, submounts) + inode, err = addSubmountOverlay(ctx, inode, submounts, mf) if err != nil { return fmt.Errorf("adding submount overlay: %v", err) } @@ -904,7 +906,7 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun // addRestoreMount adds a mount to the MountSources map used for restoring a // checkpointed container. -func (c *containerMounter) addRestoreMount(conf *Config, renv *fs.RestoreEnvironment, m specs.Mount) error { +func (c *containerMounter) addRestoreMount(conf *config.Config, renv *fs.RestoreEnvironment, m specs.Mount) error { fsName, opts, useOverlay, err := c.getMountNameAndOptions(conf, m) if err != nil { return err @@ -929,7 +931,7 @@ func (c *containerMounter) addRestoreMount(conf *Config, renv *fs.RestoreEnviron // createRestoreEnvironment builds a fs.RestoreEnvironment called renv by adding // the mounts to the environment. -func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEnvironment, error) { +func (c *containerMounter) createRestoreEnvironment(conf *config.Config) (*fs.RestoreEnvironment, error) { renv := &fs.RestoreEnvironment{ MountSources: make(map[string][]fs.MountArgs), } @@ -984,7 +986,7 @@ func (c *containerMounter) createRestoreEnvironment(conf *Config) (*fs.RestoreEn // // Note that when there are submounts inside of '/tmp', directories for the // mount points must be present, making '/tmp' not empty anymore. -func (c *containerMounter) mountTmp(ctx context.Context, conf *Config, mns *fs.MountNamespace, root *fs.Dirent) error { +func (c *containerMounter) mountTmp(ctx context.Context, conf *config.Config, mns *fs.MountNamespace, root *fs.Dirent) error { for _, m := range c.mounts { if filepath.Clean(m.Destination) == "/tmp" { log.Debugf("Explict %q mount found, skipping internal tmpfs, mount: %+v", "/tmp", m) diff --git a/runsc/boot/fs_test.go b/runsc/boot/fs_test.go index 912037075..e986231e5 100644 --- a/runsc/boot/fs_test.go +++ b/runsc/boot/fs_test.go @@ -20,6 +20,7 @@ import ( "testing" specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/runsc/config" ) func TestPodMountHintsHappy(t *testing.T) { @@ -196,7 +197,7 @@ func TestGetMountAccessType(t *testing.T) { for _, tst := range []struct { name string annotations map[string]string - want FileAccessType + want config.FileAccessType }{ { name: "container=exclusive", @@ -205,7 +206,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "container", }, - want: FileAccessExclusive, + want: config.FileAccessExclusive, }, { name: "pod=shared", @@ -214,7 +215,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "pod", }, - want: FileAccessShared, + want: config.FileAccessShared, }, { name: "shared=shared", @@ -223,7 +224,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "shared", }, - want: FileAccessShared, + want: config.FileAccessShared, }, { name: "default=shared", @@ -232,7 +233,7 @@ func TestGetMountAccessType(t *testing.T) { MountPrefix + "mount1.type": "bind", MountPrefix + "mount1.share": "container", }, - want: FileAccessShared, + want: config.FileAccessShared, }, } { t.Run(tst.name, func(t *testing.T) { diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 40c6f99fd..4940ea96a 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -27,8 +27,10 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/memutil" "gvisor.dev/gvisor/pkg/rand" @@ -67,7 +69,9 @@ import ( "gvisor.dev/gvisor/runsc/boot/filter" _ "gvisor.dev/gvisor/runsc/boot/platforms" // register all platforms. "gvisor.dev/gvisor/runsc/boot/pprof" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" + "gvisor.dev/gvisor/runsc/specutils/seccomp" // Include supported socket providers. "gvisor.dev/gvisor/pkg/sentry/socket/hostinet" @@ -79,7 +83,7 @@ import ( ) type containerInfo struct { - conf *Config + conf *config.Config // spec is the base configuration for the root container. spec *specs.Spec @@ -88,10 +92,10 @@ type containerInfo struct { procArgs kernel.CreateProcessArgs // stdioFDs contains stdin, stdout, and stderr. - stdioFDs []int + stdioFDs []*fd.FD // goferFDs are the FDs that attach the sandbox to the gofers. - goferFDs []int + goferFDs []*fd.FD } // Loader keeps state needed to start the kernel and run the container.. @@ -165,7 +169,7 @@ type Args struct { // Spec is the sandbox specification. Spec *specs.Spec // Conf is the system configuration. - Conf *Config + Conf *config.Config // ControllerFD is the FD to the URPC controller. The Loader takes ownership // of this FD and may close it at any time. ControllerFD int @@ -355,12 +359,17 @@ func New(args Args) (*Loader, error) { k.SetHostMount(hostMount) } + info := containerInfo{ + conf: args.Conf, + spec: args.Spec, + procArgs: procArgs, + } + // Make host FDs stable between invocations. Host FDs must map to the exact // same number when the sandbox is restored. Otherwise the wrong FD will be // used. - var stdioFDs []int newfd := startingStdioFD - for _, fd := range args.StdioFDs { + for _, stdioFD := range args.StdioFDs { // Check that newfd is unused to avoid clobbering over it. if _, err := unix.FcntlInt(uintptr(newfd), unix.F_GETFD, 0); !errors.Is(err, unix.EBADF) { if err != nil { @@ -369,14 +378,17 @@ func New(args Args) (*Loader, error) { return nil, fmt.Errorf("unable to remap stdios, FD %d is already in use", newfd) } - err := unix.Dup3(fd, newfd, unix.O_CLOEXEC) + err := unix.Dup3(stdioFD, newfd, unix.O_CLOEXEC) if err != nil { - return nil, fmt.Errorf("dup3 of stdioFDs failed: %v", err) + return nil, fmt.Errorf("dup3 of stdios failed: %w", err) } - stdioFDs = append(stdioFDs, newfd) - _ = unix.Close(fd) + info.stdioFDs = append(info.stdioFDs, fd.New(newfd)) + _ = unix.Close(stdioFD) newfd++ } + for _, goferFD := range args.GoferFDs { + info.goferFDs = append(info.goferFDs, fd.New(goferFD)) + } eid := execID{cid: args.ID} l := &Loader{ @@ -385,13 +397,7 @@ func New(args Args) (*Loader, error) { sandboxID: args.ID, processes: map[execID]*execProcess{eid: {}}, mountHints: mountHints, - root: containerInfo{ - conf: args.Conf, - stdioFDs: stdioFDs, - goferFDs: args.GoferFDs, - spec: args.Spec, - procArgs: procArgs, - }, + root: info, } // We don't care about child signals; some platforms can generate a @@ -465,13 +471,18 @@ func (l *Loader) Destroy() { } l.watchdog.Stop() - for i, fd := range l.root.stdioFDs { - _ = unix.Close(fd) - l.root.stdioFDs[i] = -1 + // In the success case, stdioFDs and goferFDs will only contain + // released/closed FDs that ownership has been passed over to host FDs and + // gofer sessions. Close them here in case on failure. + for _, fd := range l.root.stdioFDs { + _ = fd.Close() + } + for _, fd := range l.root.goferFDs { + _ = fd.Close() } } -func createPlatform(conf *Config, deviceFile *os.File) (platform.Platform, error) { +func createPlatform(conf *config.Config, deviceFile *os.File) (platform.Platform, error) { p, err := platform.Lookup(conf.Platform) if err != nil { panic(fmt.Sprintf("invalid platform %v: %v", conf.Platform, err)) @@ -498,13 +509,14 @@ func createMemoryFile() (*pgalloc.MemoryFile, error) { return mf, nil } +// installSeccompFilters installs sandbox seccomp filters with the host. func (l *Loader) installSeccompFilters() error { if l.root.conf.DisableSeccomp { filter.Report("syscall filter is DISABLED. Running in less secure mode.") } else { opts := filter.Options{ Platform: l.k.Platform, - HostNetwork: l.root.conf.Network == NetworkHost, + HostNetwork: l.root.conf.Network == config.NetworkHost, ProfileEnable: l.root.conf.ProfileEnable, ControllerFD: l.ctrl.srv.FD(), } @@ -531,7 +543,7 @@ func (l *Loader) Run() error { } func (l *Loader) run() error { - if l.root.conf.Network == NetworkHost { + if l.root.conf.Network == config.NetworkHost { // Delay host network configuration to this point because network namespace // is configured after the loader is created and before Run() is called. log.Debugf("Configuring host network") @@ -568,6 +580,7 @@ func (l *Loader) run() error { if _, err := l.createContainerProcess(true, l.sandboxID, &l.root, ep); err != nil { return err } + } ep.tg = l.k.GlobalInit() @@ -597,17 +610,6 @@ func (l *Loader) run() error { } }) - // l.stdioFDs are derived from dup() in boot.New() and they are now dup()ed again - // either in createFDTable() during initial start or in descriptor.initAfterLoad() - // during restore, we can release l.stdioFDs now. VFS2 takes ownership of the - // passed FDs, so only close for VFS1. - if !kernel.VFS2Enabled { - for i, fd := range l.root.stdioFDs { - _ = unix.Close(fd) - l.root.stdioFDs[i] = -1 - } - } - log.Infof("Process should have started...") l.watchdog.Start() return l.k.Start() @@ -627,9 +629,9 @@ func (l *Loader) createContainer(cid string) error { } // startContainer starts a child container. It returns the thread group ID of -// the newly created process. Caller owns 'files' and may close them after -// this method returns. -func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, files []*os.File) error { +// the newly created process. Used FDs are either closed or released. It's safe +// for the caller to close any remaining files upon return. +func (l *Loader) startContainer(spec *specs.Spec, conf *config.Config, cid string, files []*fd.FD) error { // Create capabilities. caps, err := specutils.Capabilities(conf.EnableRaw, spec.Process.Capabilities) if err != nil { @@ -680,28 +682,15 @@ func (l *Loader) startContainer(spec *specs.Spec, conf *Config, cid string, file } info := &containerInfo{ - conf: conf, - spec: spec, + conf: conf, + spec: spec, + stdioFDs: files[:3], + goferFDs: files[3:], } info.procArgs, err = createProcessArgs(cid, spec, creds, l.k, pidns) if err != nil { return fmt.Errorf("creating new process: %v", err) } - - // setupContainerFS() dups stdioFDs, so we don't need to dup them here. - for _, f := range files[:3] { - info.stdioFDs = append(info.stdioFDs, int(f.Fd())) - } - - // Can't take ownership away from os.File. dup them to get a new FDs. - for _, f := range files[3:] { - fd, err := unix.Dup(int(f.Fd())) - if err != nil { - return fmt.Errorf("failed to dup file: %v", err) - } - info.goferFDs = append(info.goferFDs, fd) - } - tg, err := l.createContainerProcess(false, cid, info, ep) if err != nil { return err @@ -779,19 +768,44 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn } } + // Install seccomp filters with the new task if there are any. + if info.conf.OCISeccomp { + if info.spec.Linux != nil && info.spec.Linux.Seccomp != nil { + program, err := seccomp.BuildProgram(info.spec.Linux.Seccomp) + if err != nil { + return nil, fmt.Errorf("building seccomp program: %v", err) + } + + if log.IsLogging(log.Debug) { + out, _ := bpf.DecodeProgram(program) + log.Debugf("Installing OCI seccomp filters\nProgram:\n%s", out) + } + + task := tg.Leader() + // NOTE: It seems Flags are ignored by runc so we ignore them too. + if err := task.AppendSyscallFilter(program, true); err != nil { + return nil, fmt.Errorf("appending seccomp filters: %v", err) + } + } + } else { + if info.spec.Linux != nil && info.spec.Linux.Seccomp != nil { + log.Warningf("Seccomp spec is being ignored") + } + } + return tg, nil } // startGoferMonitor runs a goroutine to monitor gofer's health. It polls on // the gofer FDs looking for disconnects, and destroys the container if a // disconnect occurs in any of the gofer FDs. -func (l *Loader) startGoferMonitor(cid string, goferFDs []int) { +func (l *Loader) startGoferMonitor(cid string, goferFDs []*fd.FD) { go func() { log.Debugf("Monitoring gofer health for container %q", cid) var events []unix.PollFd - for _, fd := range goferFDs { + for _, goferFD := range goferFDs { events = append(events, unix.PollFd{ - Fd: int32(fd), + Fd: int32(goferFD.FD()), Events: unix.POLLHUP | unix.POLLRDHUP, }) } @@ -1017,17 +1031,17 @@ func (l *Loader) WaitExit() kernel.ExitStatus { return l.k.GlobalInit().ExitStatus() } -func newRootNetworkNamespace(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) { +func newRootNetworkNamespace(conf *config.Config, clock tcpip.Clock, uniqueID stack.UniqueID) (*inet.Namespace, error) { // Create an empty network stack because the network namespace may be empty at // this point. Netns is configured before Run() is called. Netstack is // configured using a control uRPC message. Host network is configured inside // Run(). switch conf.Network { - case NetworkHost: + case config.NetworkHost: // No network namespacing support for hostinet yet, hence creator is nil. return inet.NewRootNamespace(hostinet.NewStack(), nil), nil - case NetworkNone, NetworkSandbox: + case config.NetworkNone, config.NetworkSandbox: s, err := newEmptySandboxNetworkStack(clock, uniqueID) if err != nil { return nil, err @@ -1060,17 +1074,30 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, uniqueID stack.UniqueID) (in })} // Enable SACK Recovery. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil { - return nil, fmt.Errorf("failed to enable SACK: %s", err) + { + opt := tcpip.TCPSACKEnabled(true) + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Set default TTLs as required by socket/netstack. - s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) - s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, tcpip.DefaultTTLOption(netstack.DefaultTTL)) + { + opt := tcpip.DefaultTTLOption(netstack.DefaultTTL) + if err := s.Stack.SetNetworkProtocolOption(ipv4.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv4.ProtocolNumber, opt, opt, err) + } + if err := s.Stack.SetNetworkProtocolOption(ipv6.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetNetworkProtocolOption(%d, &%T(%d)): %s", ipv6.ProtocolNumber, opt, opt, err) + } + } // Enable Receive Buffer Auto-Tuning. - if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(true) + if err := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } return &s, nil @@ -1266,7 +1293,7 @@ func (l *Loader) ttyFromIDLocked(key execID) (*host.TTYFileOperations, *hostvfs2 return ep.tty, ep.ttyVFS2, nil } -func createFDTable(ctx context.Context, console bool, stdioFDs []int) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { +func createFDTable(ctx context.Context, console bool, stdioFDs []*fd.FD) (*kernel.FDTable, *host.TTYFileOperations, *hostvfs2.TTYFileDescription, error) { if len(stdioFDs) != 3 { return nil, nil, nil, fmt.Errorf("stdioFDs should contain exactly 3 FDs (stdin, stdout, and stderr), but %d FDs received", len(stdioFDs)) } diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index aa3fdf96c..bf9ec5d38 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -26,6 +26,7 @@ import ( specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/control/server" + "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" @@ -34,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/fsgofer" ) @@ -43,15 +45,19 @@ func init() { if err := fsgofer.OpenProcSelfFD(); err != nil { panic(err) } + config.RegisterFlags() } -func testConfig() *Config { - return &Config{ - RootDir: "unused_root_dir", - Network: NetworkNone, - DisableSeccomp: true, - Platform: "ptrace", +func testConfig() *config.Config { + conf, err := config.NewFromFlags() + if err != nil { + panic(err) } + // Change test defaults. + conf.RootDir = "unused_root_dir" + conf.Network = config.NetworkNone + conf.DisableSeccomp = true + return conf } // testSpec returns a simple spec that can be used in tests. @@ -439,7 +445,7 @@ func TestCreateMountNamespace(t *testing.T) { } defer cleanup() - mntr := newContainerMounter(&tc.spec, []int{sandEnd}, nil, &podMountHints{}) + mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}) mns, err := mntr.createMountNamespace(ctx, conf) if err != nil { t.Fatalf("failed to create mount namespace: %v", err) @@ -485,9 +491,9 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { } ctx := l.k.SupervisorContext() - mns, err := mntr.setupVFS2(ctx, l.root.conf, &l.root.procArgs) + mns, err := mntr.mountAll(l.root.conf, &l.root.procArgs) if err != nil { - t.Fatalf("failed to setupVFS2: %v", err) + t.Fatalf("mountAll: %v", err) } root := mns.Root() @@ -545,7 +551,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, }, "tmpfs": { @@ -599,7 +605,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, { Dev: "9pfs-/dev/fd-foo", @@ -657,7 +663,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, }, "tmpfs": { @@ -697,7 +703,11 @@ func TestRestoreEnvironment(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { conf := testConfig() - mntr := newContainerMounter(tc.spec, tc.ioFDs, nil, &podMountHints{}) + var ioFDs []*fd.FD + for _, ioFD := range tc.ioFDs { + ioFDs = append(ioFDs, fd.New(ioFD)) + } + mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}) actualRenv, err := mntr.createRestoreEnvironment(conf) if !tc.errorExpected && err != nil { t.Fatalf("could not create restore environment for test:%s", tc.name) diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 4e1fa7665..988573640 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" + "gvisor.dev/gvisor/runsc/config" ) var ( @@ -78,44 +79,6 @@ type DefaultRoute struct { Name string } -// QueueingDiscipline is used to specify the kind of Queueing Discipline to -// apply for a give FDBasedLink. -type QueueingDiscipline int - -const ( - // QDiscNone disables any queueing for the underlying FD. - QDiscNone QueueingDiscipline = iota - - // QDiscFIFO applies a simple fifo based queue to the underlying - // FD. - QDiscFIFO -) - -// MakeQueueingDiscipline if possible the equivalent QueuingDiscipline for s -// else returns an error. -func MakeQueueingDiscipline(s string) (QueueingDiscipline, error) { - switch s { - case "none": - return QDiscNone, nil - case "fifo": - return QDiscFIFO, nil - default: - return 0, fmt.Errorf("unsupported qdisc specified: %q", s) - } -} - -// String implements fmt.Stringer. -func (q QueueingDiscipline) String() string { - switch q { - case QDiscNone: - return "none" - case QDiscFIFO: - return "fifo" - default: - panic(fmt.Sprintf("Invalid queueing discipline: %d", q)) - } -} - // FDBasedLink configures an fd-based link. type FDBasedLink struct { Name string @@ -127,7 +90,7 @@ type FDBasedLink struct { TXChecksumOffload bool RXChecksumOffload bool LinkAddress net.HardwareAddr - QDisc QueueingDiscipline + QDisc config.QueueingDiscipline // NumChannels controls how many underlying FD's are to be used to // create this endpoint. @@ -247,8 +210,8 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct } switch link.QDisc { - case QDiscNone: - case QDiscFIFO: + case config.QDiscNone: + case config.QDiscFIFO: log.Infof("Enabling FIFO QDisc on %q", link.Name) linkEP = fifo.New(linkEP, runtime.GOMAXPROCS(0), 1000) } diff --git a/runsc/boot/strace.go b/runsc/boot/strace.go index fbfd3b07c..c21648a32 100644 --- a/runsc/boot/strace.go +++ b/runsc/boot/strace.go @@ -15,10 +15,13 @@ package boot import ( + "strings" + "gvisor.dev/gvisor/pkg/sentry/strace" + "gvisor.dev/gvisor/runsc/config" ) -func enableStrace(conf *Config) error { +func enableStrace(conf *config.Config) error { // We must initialize even if strace is not enabled. strace.Initialize() @@ -36,5 +39,5 @@ func enableStrace(conf *Config) error { strace.EnableAll(strace.SinkTypeLog) return nil } - return strace.Enable(conf.StraceSyscalls, strace.SinkTypeLog) + return strace.Enable(strings.Split(conf.StraceSyscalls, ","), strace.SinkTypeLog) } diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 08dce8b6c..e36664938 100644 --- a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -16,12 +16,12 @@ package boot import ( "fmt" - "path" "sort" "strings" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/log" @@ -42,6 +42,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/runsc/config" ) func registerFilesystems(k *kernel.Kernel) error { @@ -133,8 +134,8 @@ func registerFilesystems(k *kernel.Kernel) error { return nil } -func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { - mns, err := mntr.setupVFS2(ctx, conf, procArgs) +func setupContainerVFS2(ctx context.Context, conf *config.Config, mntr *containerMounter, procArgs *kernel.CreateProcessArgs) error { + mns, err := mntr.mountAll(conf, procArgs) if err != nil { return fmt.Errorf("failed to setupFS: %w", err) } @@ -149,7 +150,7 @@ func setupContainerVFS2(ctx context.Context, conf *Config, mntr *containerMounte return nil } -func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) { +func (c *containerMounter) mountAll(conf *config.Config, procArgs *kernel.CreateProcessArgs) (*vfs.MountNamespace, error) { log.Infof("Configuring container's file system with VFS2") // Create context with root credentials to mount the filesystem (the current @@ -168,35 +169,109 @@ func (c *containerMounter) setupVFS2(ctx context.Context, conf *Config, procArgs } rootProcArgs.MountNamespaceVFS2 = mns + root := mns.Root() + defer root.DecRef(rootCtx) + if root.Mount().ReadOnly() { + // Switch to ReadWrite while we setup submounts. + if err := c.k.VFS().SetMountReadOnly(root.Mount(), false); err != nil { + return nil, fmt.Errorf(`failed to set mount at "/" readwrite: %w`, err) + } + // Restore back to ReadOnly at the end. + defer func() { + if err := c.k.VFS().SetMountReadOnly(root.Mount(), true); err != nil { + panic(fmt.Sprintf(`failed to restore mount at "/" back to readonly: %v`, err)) + } + }() + } + // Mount submounts. if err := c.mountSubmountsVFS2(rootCtx, conf, mns, rootCreds); err != nil { return nil, fmt.Errorf("mounting submounts vfs2: %w", err) } + return mns, nil } -func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *Config, creds *auth.Credentials) (*vfs.MountNamespace, error) { +// createMountNamespaceVFS2 creates the container's root mount and namespace. +func (c *containerMounter) createMountNamespaceVFS2(ctx context.Context, conf *config.Config, creds *auth.Credentials) (*vfs.MountNamespace, error) { fd := c.fds.remove() - opts := p9MountData(fd, conf.FileAccess, true /* vfs2 */) + data := p9MountData(fd, conf.FileAccess, true /* vfs2 */) if conf.OverlayfsStaleRead { // We can't check for overlayfs here because sandbox is chroot'ed and gofer // can only send mount options for specs.Mounts (specs.Root is missing // Options field). So assume root is always on top of overlayfs. - opts = append(opts, "overlayfs_stale_read") + data = append(data, "overlayfs_stale_read") } log.Infof("Mounting root over 9P, ioFD: %d", fd) - mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", gofer.Name, &vfs.GetFilesystemOptions{ - Data: strings.Join(opts, ","), - }) + opts := &vfs.MountOptions{ + ReadOnly: c.root.Readonly, + GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: strings.Join(data, ","), + }, + InternalMount: true, + } + + fsName := gofer.Name + if conf.Overlay && !c.root.Readonly { + log.Infof("Adding overlay on top of root") + var err error + var cleanup func() + opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) + if err != nil { + return nil, fmt.Errorf("mounting root with overlay: %w", err) + } + defer cleanup() + fsName = overlay.Name + } + + mns, err := c.k.VFS().NewMountNamespace(ctx, creds, "", fsName, opts) if err != nil { return nil, fmt.Errorf("setting up mount namespace: %w", err) } return mns, nil } -func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials) error { +// configureOverlay mounts the lower layer using "lowerOpts", mounts the upper +// layer using tmpfs, and return overlay mount options. "cleanup" must be called +// after the options have been used to mount the overlay, to release refs on +// lower and upper mounts. +func (c *containerMounter) configureOverlay(ctx context.Context, creds *auth.Credentials, lowerOpts *vfs.MountOptions, lowerFSName string) (*vfs.MountOptions, func(), error) { + // First copy options from lower layer to upper layer and overlay. Clear + // filesystem specific options. + upperOpts := *lowerOpts + upperOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{} + + overlayOpts := *lowerOpts + overlayOpts.GetFilesystemOptions = vfs.GetFilesystemOptions{} + + // Next mount upper and lower. Upper is a tmpfs mount to keep all + // modifications inside the sandbox. + upper, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, tmpfs.Name, &upperOpts) + if err != nil { + return nil, nil, fmt.Errorf("failed to create upper layer for overlay, opts: %+v: %v", upperOpts, err) + } + cu := cleanup.Make(func() { upper.DecRef(ctx) }) + defer cu.Clean() + + // All writes go to the upper layer, be paranoid and make lower readonly. + lowerOpts.ReadOnly = true + lower, err := c.k.VFS().MountDisconnected(ctx, creds, "" /* source */, lowerFSName, lowerOpts) + if err != nil { + return nil, nil, err + } + cu.Add(func() { lower.DecRef(ctx) }) + + // Configure overlay with both layers. + overlayOpts.GetFilesystemOptions.InternalData = overlay.FilesystemOptions{ + UpperRoot: vfs.MakeVirtualDentry(upper, upper.Root()), + LowerRoots: []vfs.VirtualDentry{vfs.MakeVirtualDentry(lower, lower.Root())}, + } + return &overlayOpts, cu.Release(), nil +} + +func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials) error { mounts, err := c.prepareMountsVFS2() if err != nil { return err @@ -205,15 +280,35 @@ func (c *containerMounter) mountSubmountsVFS2(ctx context.Context, conf *Config, for i := range mounts { submount := &mounts[i] log.Debugf("Mounting %q to %q, type: %s, options: %s", submount.Source, submount.Destination, submount.Type, submount.Options) + var ( + mnt *vfs.Mount + err error + ) + if hint := c.hints.findMount(submount.Mount); hint != nil && hint.isSupported() { - if err := c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint); err != nil { + mnt, err = c.mountSharedSubmountVFS2(ctx, conf, mns, creds, submount.Mount, hint) + if err != nil { return fmt.Errorf("mount shared mount %q to %q: %v", hint.name, submount.Destination, err) } } else { - if err := c.mountSubmountVFS2(ctx, conf, mns, creds, submount); err != nil { + mnt, err = c.mountSubmountVFS2(ctx, conf, mns, creds, submount) + if err != nil { return fmt.Errorf("mount submount %q: %w", submount.Destination, err) } } + + if mnt != nil && mnt.ReadOnly() { + // Switch to ReadWrite while we setup submounts. + if err := c.k.VFS().SetMountReadOnly(mnt, false); err != nil { + return fmt.Errorf("failed to set mount at %q readwrite: %w", submount.Destination, err) + } + // Restore back to ReadOnly at the end. + defer func() { + if err := c.k.VFS().SetMountReadOnly(mnt, true); err != nil { + panic(fmt.Sprintf("failed to restore mount at %q back to readonly: %v", submount.Destination, err)) + } + }() + } } if err := c.mountTmpVFS2(ctx, conf, creds, mns); err != nil { @@ -256,7 +351,31 @@ func (c *containerMounter) prepareMountsVFS2() ([]mountAndFD, error) { return mounts, nil } -func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) error { +func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, submount *mountAndFD) (*vfs.Mount, error) { + fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, submount) + if err != nil { + return nil, fmt.Errorf("mountOptions failed: %w", err) + } + if len(fsName) == 0 { + // Filesystem is not supported (e.g. cgroup), just skip it. + return nil, nil + } + + if err := c.makeMountPoint(ctx, creds, mns, submount.Destination); err != nil { + return nil, fmt.Errorf("creating mount point %q: %w", submount.Destination, err) + } + + if useOverlay { + log.Infof("Adding overlay on top of mount %q", submount.Destination) + var cleanup func() + opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) + if err != nil { + return nil, fmt.Errorf("mounting volume with overlay at %q: %w", submount.Destination, err) + } + defer cleanup() + fsName = overlay.Name + } + root := mns.Root() defer root.DecRef(ctx) target := &vfs.PathOperation{ @@ -264,29 +383,19 @@ func (c *containerMounter) mountSubmountVFS2(ctx context.Context, conf *Config, Start: root, Path: fspath.Parse(submount.Destination), } - fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, submount) + mnt, err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts) if err != nil { - return fmt.Errorf("mountOptions failed: %w", err) - } - if len(fsName) == 0 { - // Filesystem is not supported (e.g. cgroup), just skip it. - return nil - } - - if err := c.makeSyntheticMount(ctx, submount.Destination, root, creds); err != nil { - return err - } - if err := c.k.VFS().MountAt(ctx, creds, "", target, fsName, opts); err != nil { - return fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts) + return nil, fmt.Errorf("failed to mount %q (type: %s): %w, opts: %v", submount.Destination, submount.Type, err, opts) } log.Infof("Mounted %q to %q type: %s, internal-options: %q", submount.Source, submount.Destination, submount.Type, opts.GetFilesystemOptions.Data) - return nil + return mnt, nil } // getMountNameAndOptionsVFS2 retrieves the fsName, opts, and useOverlay values // used for mounts. -func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndFD) (string, *vfs.MountOptions, error) { +func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mountAndFD) (string, *vfs.MountOptions, bool, error) { fsName := m.Type + useOverlay := false var data []string // Find filesystem name and FS specific data field. @@ -301,7 +410,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndF var err error data, err = parseAndFilterOptions(m.Options, tmpfsAllowedData...) if err != nil { - return "", nil, err + return "", nil, false, err } case bind: @@ -309,13 +418,16 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndF if m.fd == 0 { // Check that an FD was provided to fails fast. Technically FD=0 is valid, // but unlikely to be correct in this context. - return "", nil, fmt.Errorf("9P mount requires a connection FD") + return "", nil, false, fmt.Errorf("9P mount requires a connection FD") } data = p9MountData(m.fd, c.getMountAccessType(m.Mount), true /* vfs2 */) + // If configured, add overlay to all writable mounts. + useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly + default: log.Warningf("ignoring unknown filesystem type %q", m.Type) - return "", nil, nil + return "", nil, false, nil } opts := &vfs.MountOptions{ @@ -340,38 +452,7 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *Config, m *mountAndF } } - if conf.Overlay { - // All writes go to upper, be paranoid and make lower readonly. - opts.ReadOnly = true - } - return fsName, opts, nil -} - -func (c *containerMounter) makeSyntheticMount(ctx context.Context, currentPath string, root vfs.VirtualDentry, creds *auth.Credentials) error { - target := &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(currentPath), - } - _, err := c.k.VFS().StatAt(ctx, creds, target, &vfs.StatOptions{}) - if err == nil { - log.Debugf("Mount point %q already exists", currentPath) - return nil - } - if err != syserror.ENOENT { - return fmt.Errorf("stat failed for %q during mount point creation: %w", currentPath, err) - } - - // Recurse to ensure parent is created and then create the mount point. - if err := c.makeSyntheticMount(ctx, path.Dir(currentPath), root, creds); err != nil { - return err - } - log.Debugf("Creating dir %q for mount point", currentPath) - mkdirOpts := &vfs.MkdirOptions{Mode: 0777, ForSyntheticMountpoint: true} - if err := c.k.VFS().MkdirAt(ctx, creds, target, mkdirOpts); err != nil { - return fmt.Errorf("failed to create directory %q for mount: %w", currentPath, err) - } - return nil + return fsName, opts, useOverlay, nil } // mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so. @@ -383,7 +464,7 @@ func (c *containerMounter) makeSyntheticMount(ctx context.Context, currentPath s // // Note that when there are submounts inside of '/tmp', directories for the // mount points must be present, making '/tmp' not empty anymore. -func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds *auth.Credentials, mns *vfs.MountNamespace) error { +func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *config.Config, creds *auth.Credentials, mns *vfs.MountNamespace) error { for _, m := range c.mounts { // m.Destination has been cleaned, so it's to use equality here. if m.Destination == "/tmp" { @@ -434,7 +515,8 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds // another user. This is normally done for /tmp. Options: []string{"mode=01777"}, } - return c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount}) + _, err := c.mountSubmountVFS2(ctx, conf, mns, creds, &mountAndFD{Mount: tmpMount}) + return err case syserror.ENOTDIR: // Not a dir?! Let it be. @@ -448,7 +530,7 @@ func (c *containerMounter) mountTmpVFS2(ctx context.Context, conf *Config, creds // processHintsVFS2 processes annotations that container hints about how volumes // should be mounted (e.g. a volume shared between containers). It must be // called for the root container only. -func (c *containerMounter) processHintsVFS2(conf *Config, creds *auth.Credentials) error { +func (c *containerMounter) processHintsVFS2(conf *config.Config, creds *auth.Credentials) error { ctx := c.k.SupervisorContext() for _, hint := range c.hints.mounts { // TODO(b/142076984): Only support tmpfs for now. Bind mounts require a @@ -469,51 +551,86 @@ func (c *containerMounter) processHintsVFS2(conf *Config, creds *auth.Credential // mountSharedMasterVFS2 mounts the master of a volume that is shared among // containers in a pod. -func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) { +func (c *containerMounter) mountSharedMasterVFS2(ctx context.Context, conf *config.Config, hint *mountHint, creds *auth.Credentials) (*vfs.Mount, error) { // Map mount type to filesystem name, and parse out the options that we are // capable of dealing with. mntFD := &mountAndFD{Mount: hint.mount} - fsName, opts, err := c.getMountNameAndOptionsVFS2(conf, mntFD) + fsName, opts, useOverlay, err := c.getMountNameAndOptionsVFS2(conf, mntFD) if err != nil { return nil, err } if len(fsName) == 0 { return nil, fmt.Errorf("mount type not supported %q", hint.mount.Type) } + + if useOverlay { + log.Infof("Adding overlay on top of shared mount %q", mntFD.Destination) + var cleanup func() + opts, cleanup, err = c.configureOverlay(ctx, creds, opts, fsName) + if err != nil { + return nil, fmt.Errorf("mounting shared volume with overlay at %q: %w", mntFD.Destination, err) + } + defer cleanup() + fsName = overlay.Name + } + return c.k.VFS().MountDisconnected(ctx, creds, "", fsName, opts) } // mountSharedSubmount binds mount to a previously mounted volume that is shared // among containers in the same pod. -func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) error { +func (c *containerMounter) mountSharedSubmountVFS2(ctx context.Context, conf *config.Config, mns *vfs.MountNamespace, creds *auth.Credentials, mount specs.Mount, source *mountHint) (*vfs.Mount, error) { if err := source.checkCompatible(mount); err != nil { - return err + return nil, err } - _, opts, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount}) + // Ignore data and useOverlay because these were already applied to + // the master mount. + _, opts, _, err := c.getMountNameAndOptionsVFS2(conf, &mountAndFD{Mount: mount}) if err != nil { - return err + return nil, err } newMnt, err := c.k.VFS().NewDisconnectedMount(source.vfsMount.Filesystem(), source.vfsMount.Root(), opts) if err != nil { - return err + return nil, err } defer newMnt.DecRef(ctx) root := mns.Root() defer root.DecRef(ctx) - if err := c.makeSyntheticMount(ctx, mount.Destination, root, creds); err != nil { - return err - } - target := &vfs.PathOperation{ Root: root, Start: root, Path: fspath.Parse(mount.Destination), } + + if err := c.makeMountPoint(ctx, creds, mns, mount.Destination); err != nil { + return nil, fmt.Errorf("creating mount point %q: %w", mount.Destination, err) + } + if err := c.k.VFS().ConnectMountAt(ctx, creds, newMnt, target); err != nil { - return err + return nil, err } log.Infof("Mounted %q type shared bind to %q", mount.Destination, source.name) - return nil + return newMnt, nil +} + +func (c *containerMounter) makeMountPoint(ctx context.Context, creds *auth.Credentials, mns *vfs.MountNamespace, dest string) error { + root := mns.Root() + defer root.DecRef(ctx) + target := &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(dest), + } + // First check if mount point exists. When overlay is enabled, gofer doesn't + // allow changes to the FS, making MakeSytheticMountpoint() ineffective + // because MkdirAt fails with EROFS even if file exists. + vd, err := c.k.VFS().GetDentryAt(ctx, creds, target, &vfs.GetDentryOptions{}) + if err == nil { + // File exists, we're done. + vd.DecRef(ctx) + return nil + } + return c.k.VFS().MakeSyntheticMountpoint(ctx, dest, root, creds) } diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 1b5178dd5..2556f6d9e 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -51,6 +51,7 @@ go_library( "//pkg/unet", "//pkg/urpc", "//runsc/boot", + "//runsc/config", "//runsc/console", "//runsc/container", "//runsc/flag", @@ -84,7 +85,7 @@ go_test( "//pkg/sentry/kernel/auth", "//pkg/test/testutil", "//pkg/urpc", - "//runsc/boot", + "//runsc/config", "//runsc/container", "//runsc/specutils", "@com_github_google_go_cmp//cmp:go_default_library", diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go index f4f247721..cd419e1aa 100644 --- a/runsc/cmd/boot.go +++ b/runsc/cmd/boot.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/platform" "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" ) @@ -133,7 +134,7 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Ensure that if there is a panic, all goroutine stacks are printed. debug.SetTraceback("system") - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if b.attached { // Ensure this process is killed after parent process terminates when @@ -167,7 +168,7 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Get the spec from the specFD. specFile := os.NewFile(uintptr(b.specFD), "spec file") defer specFile.Close() - spec, err := specutils.ReadSpecFromFile(b.bundleDir, specFile) + spec, err := specutils.ReadSpecFromFile(b.bundleDir, specFile, conf) if err != nil { Fatalf("reading spec: %v", err) } diff --git a/runsc/cmd/capability_test.go b/runsc/cmd/capability_test.go index a84067112..e13a94486 100644 --- a/runsc/cmd/capability_test.go +++ b/runsc/cmd/capability_test.go @@ -24,7 +24,7 @@ import ( "github.com/syndtr/gocapability/capability" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/specutils" ) @@ -88,7 +88,7 @@ func TestCapabilities(t *testing.T) { conf := testutil.TestConfig(t) // Use --network=host to make sandbox use spec's capabilities. - conf.Network = boot.NetworkHost + conf.Network = config.NetworkHost _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go index 8a29e521e..8fe0c427a 100644 --- a/runsc/cmd/checkpoint.go +++ b/runsc/cmd/checkpoint.go @@ -22,7 +22,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -72,7 +72,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) cont, err := container.Load(conf.RootDir, id) @@ -118,7 +118,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa Fatalf("setting bundleDir") } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { Fatalf("reading spec: %v", err) } diff --git a/runsc/cmd/create.go b/runsc/cmd/create.go index 910e97577..e76f7ba1d 100644 --- a/runsc/cmd/create.go +++ b/runsc/cmd/create.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -81,7 +81,7 @@ func (c *Create) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if conf.Rootless { return Errorf("Rootless mode not supported with %q", c.Name()) @@ -91,7 +91,7 @@ func (c *Create) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} if bundleDir == "" { bundleDir = getwdOrDie() } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { return Errorf("reading spec: %v", err) } diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index 742f8c344..132198222 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -25,7 +25,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -82,7 +82,7 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { // Execute implements subcommands.Command.Execute. func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { var c *container.Container - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if d.pid == 0 { // No pid, container ID must have been provided. diff --git a/runsc/cmd/delete.go b/runsc/cmd/delete.go index 0e4863f50..4e49deff8 100644 --- a/runsc/cmd/delete.go +++ b/runsc/cmd/delete.go @@ -21,7 +21,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -59,14 +59,14 @@ func (d *Delete) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} return subcommands.ExitUsageError } - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if err := d.execute(f.Args(), conf); err != nil { Fatalf("%v", err) } return subcommands.ExitSuccess } -func (d *Delete) execute(ids []string, conf *boot.Config) error { +func (d *Delete) execute(ids []string, conf *config.Config) error { for _, id := range ids { c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/delete_test.go b/runsc/cmd/delete_test.go index cb59516a3..e2d994a05 100644 --- a/runsc/cmd/delete_test.go +++ b/runsc/cmd/delete_test.go @@ -18,7 +18,7 @@ import ( "io/ioutil" "testing" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" ) func TestNotFound(t *testing.T) { @@ -27,7 +27,7 @@ func TestNotFound(t *testing.T) { if err != nil { t.Fatalf("error creating dir: %v", err) } - conf := &boot.Config{RootDir: dir} + conf := &config.Config{RootDir: dir} d := Delete{} if err := d.execute(ids, conf); err == nil { diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go index 7d1310c96..d1f2e9e6d 100644 --- a/runsc/cmd/do.go +++ b/runsc/cmd/do.go @@ -30,7 +30,7 @@ import ( "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -82,7 +82,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su return subcommands.ExitUsageError } - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) if conf.Rootless { @@ -125,7 +125,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su specutils.LogSpec(spec) cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000)) - if conf.Network == boot.NetworkNone { + if conf.Network == config.NetworkNone { netns := specs.LinuxNamespace{ Type: specs.NetworkNamespace, } @@ -135,9 +135,9 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su spec.Linux = &specs.Linux{Namespaces: []specs.LinuxNamespace{netns}} } else if conf.Rootless { - if conf.Network == boot.NetworkSandbox { + if conf.Network == config.NetworkSandbox { c.notifyUser("*** Warning: using host network due to --rootless ***") - conf.Network = boot.NetworkHost + conf.Network = config.NetworkHost } } else { diff --git a/runsc/cmd/events.go b/runsc/cmd/events.go index 51f6a98ed..25fe2cf1c 100644 --- a/runsc/cmd/events.go +++ b/runsc/cmd/events.go @@ -22,7 +22,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -72,7 +72,7 @@ func (evs *Events) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index d9a94903e..775ed4b43 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -33,7 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/urpc" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/console" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" @@ -105,7 +105,7 @@ func (ex *Exec) SetFlags(f *flag.FlagSet) { // Execute implements subcommands.Command.Execute. It starts a process in an // already created container. func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) e, id, err := ex.parseArgs(f, conf.EnableRaw) if err != nil { Fatalf("parsing process spec: %v", err) @@ -220,7 +220,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi cmd.Stderr = os.Stderr // If the console control socket file is provided, then create a new - // pty master/slave pair and set the TTY on the sandbox process. + // pty master/replica pair and set the TTY on the sandbox process. if ex.consoleSocket != "" { // Create a new TTY pair and send the master on the provided socket. tty, err := console.NewWithSocket(ex.consoleSocket) @@ -229,7 +229,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi } defer tty.Close() - // Set stdio to the new TTY slave. + // Set stdio to the new TTY replica. cmd.Stdin = tty cmd.Stdout = tty cmd.Stderr = tty diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 3966e2d21..371fcc0ae 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -30,7 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/unet" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/fsgofer" "gvisor.dev/gvisor/runsc/fsgofer/filter" @@ -62,9 +62,8 @@ type Gofer struct { applyCaps bool setUpRoot bool - panicOnWrite bool - specFD int - mountsFD int + specFD int + mountsFD int } // Name implements subcommands.Command. @@ -87,7 +86,6 @@ func (g *Gofer) SetFlags(f *flag.FlagSet) { f.StringVar(&g.bundleDir, "bundle", "", "path to the root of the bundle directory, defaults to the current directory") f.Var(&g.ioFDs, "io-fds", "list of FDs to connect 9P servers. They must follow this order: root first, then mounts as defined in the spec") f.BoolVar(&g.applyCaps, "apply-caps", true, "if true, apply capabilities to restrict what the Gofer process can do") - f.BoolVar(&g.panicOnWrite, "panic-on-write", false, "if true, panics on attempts to write to RO mounts. RW mounts are unnaffected") f.BoolVar(&g.setUpRoot, "setup-root", true, "if true, set up an empty root for the process") f.IntVar(&g.specFD, "spec-fd", -1, "required fd with the container spec") f.IntVar(&g.mountsFD, "mounts-fd", -1, "mountsFD is the file descriptor to write list of mounts after they have been resolved (direct paths, no symlinks).") @@ -100,15 +98,15 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitUsageError } + conf := args[0].(*config.Config) + specFile := os.NewFile(uintptr(g.specFD), "spec file") defer specFile.Close() - spec, err := specutils.ReadSpecFromFile(g.bundleDir, specFile) + spec, err := specutils.ReadSpecFromFile(g.bundleDir, specFile, conf) if err != nil { Fatalf("reading spec: %v", err) } - conf := args[0].(*boot.Config) - if g.setUpRoot { if err := setupRootFS(spec, conf); err != nil { Fatalf("Error setting up root FS: %v", err) @@ -168,8 +166,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Start with root mount, then add any other additional mount as needed. ats := make([]p9.Attacher, 0, len(spec.Mounts)+1) ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{ - ROMount: spec.Root.Readonly || conf.Overlay, - PanicOnWrite: g.panicOnWrite, + ROMount: spec.Root.Readonly || conf.Overlay, }) if err != nil { Fatalf("creating attach point: %v", err) @@ -181,9 +178,8 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) for _, m := range spec.Mounts { if specutils.Is9PMount(m) { cfg := fsgofer.Config{ - ROMount: isReadonlyMount(m.Options) || conf.Overlay, - PanicOnWrite: g.panicOnWrite, - HostUDS: conf.FSGoferHostUDS, + ROMount: isReadonlyMount(m.Options) || conf.Overlay, + HostUDS: conf.FSGoferHostUDS, } ap, err := fsgofer.NewAttachPoint(m.Destination, cfg) if err != nil { @@ -263,7 +259,7 @@ func isReadonlyMount(opts []string) bool { return false } -func setupRootFS(spec *specs.Spec, conf *boot.Config) error { +func setupRootFS(spec *specs.Spec, conf *config.Config) error { // Convert all shared mounts into slaves to be sure that nothing will be // propagated outside of our namespace. if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil { @@ -316,6 +312,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error { if err != nil { return fmt.Errorf("resolving symlinks to %q: %v", spec.Process.Cwd, err) } + log.Infof("Create working directory %q if needed", spec.Process.Cwd) if err := os.MkdirAll(dst, 0755); err != nil { return fmt.Errorf("creating working directory %q: %v", spec.Process.Cwd, err) } @@ -346,7 +343,7 @@ func setupRootFS(spec *specs.Spec, conf *boot.Config) error { // setupMounts binds mount all mounts specified in the spec in their correct // location inside root. It will resolve relative paths and symlinks. It also // creates directories as needed. -func setupMounts(conf *boot.Config, mounts []specs.Mount, root string) error { +func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { for _, m := range mounts { if m.Type != "bind" || !specutils.IsSupportedDevMount(m) { continue @@ -385,7 +382,7 @@ func setupMounts(conf *boot.Config, mounts []specs.Mount, root string) error { // Otherwise, it may follow symlinks to locations that would be overwritten // with another mount point and return the wrong location. In short, make sure // setupMounts() has been called before. -func resolveMounts(conf *boot.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) { +func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]specs.Mount, error) { cleanMounts := make([]specs.Mount, 0, len(mounts)) for _, m := range mounts { if m.Type != "bind" || !specutils.IsSupportedDevMount(m) { @@ -467,7 +464,7 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro } // adjustMountOptions adds 'overlayfs_stale_read' if mounting over overlayfs. -func adjustMountOptions(conf *boot.Config, path string, opts []string) ([]string, error) { +func adjustMountOptions(conf *config.Config, path string, opts []string) ([]string, error) { rv := make([]string, len(opts)) copy(rv, opts) diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go index 8282ea0e0..04eee99b2 100644 --- a/runsc/cmd/kill.go +++ b/runsc/cmd/kill.go @@ -23,7 +23,7 @@ import ( "github.com/google/subcommands" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -63,7 +63,7 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) if k.pid != 0 && k.all { Fatalf("it is invalid to specify both --all and --pid") diff --git a/runsc/cmd/list.go b/runsc/cmd/list.go index d8d906fe3..f92d6fef9 100644 --- a/runsc/cmd/list.go +++ b/runsc/cmd/list.go @@ -24,7 +24,7 @@ import ( "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -63,7 +63,7 @@ func (l *List) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitUsageError } - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) ids, err := container.List(conf.RootDir) if err != nil { Fatalf("%v", err) diff --git a/runsc/cmd/pause.go b/runsc/cmd/pause.go index 6f95a9837..0eb1402ed 100644 --- a/runsc/cmd/pause.go +++ b/runsc/cmd/pause.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -53,7 +53,7 @@ func (*Pause) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) cont, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/ps.go b/runsc/cmd/ps.go index 7fb8041af..bc58c928f 100644 --- a/runsc/cmd/ps.go +++ b/runsc/cmd/ps.go @@ -20,7 +20,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/sentry/control" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -58,7 +58,7 @@ func (ps *PS) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go index 72584b326..096ec814c 100644 --- a/runsc/cmd/restore.go +++ b/runsc/cmd/restore.go @@ -20,7 +20,7 @@ import ( "syscall" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -77,7 +77,7 @@ func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{ } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) if conf.Rootless { @@ -88,7 +88,7 @@ func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{ if bundleDir == "" { bundleDir = getwdOrDie() } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { return Errorf("reading spec: %v", err) } diff --git a/runsc/cmd/resume.go b/runsc/cmd/resume.go index 61a55a554..f24823f99 100644 --- a/runsc/cmd/resume.go +++ b/runsc/cmd/resume.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -54,7 +54,7 @@ func (r *Resume) Execute(_ context.Context, f *flag.FlagSet, args ...interface{} } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) cont, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go index cf41581ad..c48cbe4cd 100644 --- a/runsc/cmd/run.go +++ b/runsc/cmd/run.go @@ -19,7 +19,7 @@ import ( "syscall" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" @@ -64,7 +64,7 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) waitStatus := args[1].(*syscall.WaitStatus) if conf.Rootless { @@ -75,7 +75,7 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s if bundleDir == "" { bundleDir = getwdOrDie() } - spec, err := specutils.ReadSpec(bundleDir) + spec, err := specutils.ReadSpec(bundleDir, conf) if err != nil { return Errorf("reading spec: %v", err) } diff --git a/runsc/cmd/start.go b/runsc/cmd/start.go index 0205fd9f7..88991b521 100644 --- a/runsc/cmd/start.go +++ b/runsc/cmd/start.go @@ -18,7 +18,7 @@ import ( "context" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -52,7 +52,7 @@ func (*Start) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/state.go b/runsc/cmd/state.go index cf2413deb..2bd2ab9f8 100644 --- a/runsc/cmd/state.go +++ b/runsc/cmd/state.go @@ -21,7 +21,7 @@ import ( "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -55,7 +55,7 @@ func (*State) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go index 29c0a15f0..28d0642ed 100644 --- a/runsc/cmd/wait.go +++ b/runsc/cmd/wait.go @@ -21,7 +21,7 @@ import ( "syscall" "github.com/google/subcommands" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" ) @@ -70,7 +70,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) } id := f.Arg(0) - conf := args[0].(*boot.Config) + conf := args[0].(*config.Config) c, err := container.Load(conf.RootDir, id) if err != nil { diff --git a/runsc/config/BUILD b/runsc/config/BUILD new file mode 100644 index 000000000..b1672bb9d --- /dev/null +++ b/runsc/config/BUILD @@ -0,0 +1,28 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "config", + srcs = [ + "config.go", + "flags.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/refs", + "//pkg/sentry/watchdog", + "//pkg/sync", + "//runsc/flag", + ], +) + +go_test( + name = "config_test", + size = "small", + srcs = [ + "config_test.go", + ], + library = ":config", + deps = ["//runsc/flag"], +) diff --git a/runsc/boot/config.go b/runsc/config/config.go index 80da8b3e6..f30f79f68 100644 --- a/runsc/boot/config.go +++ b/runsc/config/config.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,220 +12,112 @@ // See the License for the specific language governing permissions and // limitations under the License. -package boot +// Package config provides basic infrastructure to set configuration settings +// for runsc. The configuration is set by flags to the command line. They can +// also propagate to a different process using the same flags. +package config import ( "fmt" - "strconv" - "strings" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/watchdog" ) -// FileAccessType tells how the filesystem is accessed. -type FileAccessType int - -const ( - // FileAccessShared sends IO requests to a Gofer process that validates the - // requests and forwards them to the host. - FileAccessShared FileAccessType = iota - - // FileAccessExclusive is the same as FileAccessShared, but enables - // extra caching for improved performance. It should only be used if - // the sandbox has exclusive access to the filesystem. - FileAccessExclusive -) - -// MakeFileAccessType converts type from string. -func MakeFileAccessType(s string) (FileAccessType, error) { - switch s { - case "shared": - return FileAccessShared, nil - case "exclusive": - return FileAccessExclusive, nil - default: - return 0, fmt.Errorf("invalid file access type %q", s) - } -} - -func (f FileAccessType) String() string { - switch f { - case FileAccessShared: - return "shared" - case FileAccessExclusive: - return "exclusive" - default: - return fmt.Sprintf("unknown(%d)", f) - } -} - -// NetworkType tells which network stack to use. -type NetworkType int - -const ( - // NetworkSandbox uses internal network stack, isolated from the host. - NetworkSandbox NetworkType = iota - - // NetworkHost redirects network related syscalls to the host network. - NetworkHost - - // NetworkNone sets up just loopback using netstack. - NetworkNone -) - -// MakeNetworkType converts type from string. -func MakeNetworkType(s string) (NetworkType, error) { - switch s { - case "sandbox": - return NetworkSandbox, nil - case "host": - return NetworkHost, nil - case "none": - return NetworkNone, nil - default: - return 0, fmt.Errorf("invalid network type %q", s) - } -} - -func (n NetworkType) String() string { - switch n { - case NetworkSandbox: - return "sandbox" - case NetworkHost: - return "host" - case NetworkNone: - return "none" - default: - return fmt.Sprintf("unknown(%d)", n) - } -} - -// MakeWatchdogAction converts type from string. -func MakeWatchdogAction(s string) (watchdog.Action, error) { - switch strings.ToLower(s) { - case "log", "logwarning": - return watchdog.LogWarning, nil - case "panic": - return watchdog.Panic, nil - default: - return 0, fmt.Errorf("invalid watchdog action %q", s) - } -} - -// MakeRefsLeakMode converts type from string. -func MakeRefsLeakMode(s string) (refs.LeakMode, error) { - switch strings.ToLower(s) { - case "disabled": - return refs.NoLeakChecking, nil - case "log-names": - return refs.LeaksLogWarning, nil - case "log-traces": - return refs.LeaksLogTraces, nil - default: - return 0, fmt.Errorf("invalid refs leakmode %q", s) - } -} - -func refsLeakModeToString(mode refs.LeakMode) string { - switch mode { - // If not set, default it to disabled. - case refs.UninitializedLeakChecking, refs.NoLeakChecking: - return "disabled" - case refs.LeaksLogWarning: - return "log-names" - case refs.LeaksLogTraces: - return "log-traces" - default: - panic(fmt.Sprintf("Invalid leakmode: %d", mode)) - } -} - // Config holds configuration that is not part of the runtime spec. +// +// Follow these steps to add a new flag: +// 1. Create a new field in Config. +// 2. Add a field tag with the flag name +// 3. Register a new flag in flags.go, with name and description +// 4. Add any necessary validation into validate() +// 5. If adding an enum, follow the same pattern as FileAccessType +// type Config struct { // RootDir is the runtime root directory. - RootDir string + RootDir string `flag:"root"` // Debug indicates that debug logging should be enabled. - Debug bool + Debug bool `flag:"debug"` // LogFilename is the filename to log to, if not empty. - LogFilename string + LogFilename string `flag:"log"` // LogFormat is the log format. - LogFormat string + LogFormat string `flag:"log-format"` // DebugLog is the path to log debug information to, if not empty. - DebugLog string + DebugLog string `flag:"debug-log"` // PanicLog is the path to log GO's runtime messages, if not empty. - PanicLog string + PanicLog string `flag:"panic-log"` // DebugLogFormat is the log format for debug. - DebugLogFormat string + DebugLogFormat string `flag:"debug-log-format"` // FileAccess indicates how the filesystem is accessed. - FileAccess FileAccessType + FileAccess FileAccessType `flag:"file-access"` // Overlay is whether to wrap the root filesystem in an overlay. - Overlay bool + Overlay bool `flag:"overlay"` // FSGoferHostUDS enables the gofer to mount a host UDS. - FSGoferHostUDS bool + FSGoferHostUDS bool `flag:"fsgofer-host-uds"` // Network indicates what type of network to use. - Network NetworkType + Network NetworkType `flag:"network"` // EnableRaw indicates whether raw sockets should be enabled. Raw // sockets are disabled by stripping CAP_NET_RAW from the list of // capabilities. - EnableRaw bool + EnableRaw bool `flag:"net-raw"` // HardwareGSO indicates that hardware segmentation offload is enabled. - HardwareGSO bool + HardwareGSO bool `flag:"gso"` // SoftwareGSO indicates that software segmentation offload is enabled. - SoftwareGSO bool + SoftwareGSO bool `flag:"software-gso"` // TXChecksumOffload indicates that TX Checksum Offload is enabled. - TXChecksumOffload bool + TXChecksumOffload bool `flag:"tx-checksum-offload"` // RXChecksumOffload indicates that RX Checksum Offload is enabled. - RXChecksumOffload bool + RXChecksumOffload bool `flag:"rx-checksum-offload"` // QDisc indicates the type of queuening discipline to use by default // for non-loopback interfaces. - QDisc QueueingDiscipline + QDisc QueueingDiscipline `flag:"qdisc"` // LogPackets indicates that all network packets should be logged. - LogPackets bool + LogPackets bool `flag:"log-packets"` // Platform is the platform to run on. - Platform string + Platform string `flag:"platform"` // Strace indicates that strace should be enabled. - Strace bool + Strace bool `flag:"strace"` - // StraceSyscalls is the set of syscalls to trace. If StraceEnable is - // true and this list is empty, then all syscalls will be traced. - StraceSyscalls []string + // StraceSyscalls is the set of syscalls to trace (comma-separated values). + // If StraceEnable is true and this string is empty, then all syscalls will + // be traced. + StraceSyscalls string `flag:"strace-syscalls"` // StraceLogSize is the max size of data blobs to display. - StraceLogSize uint + StraceLogSize uint `flag:"strace-log-size"` // DisableSeccomp indicates whether seccomp syscall filters should be // disabled. Pardon the double negation, but default to enabled is important. DisableSeccomp bool // WatchdogAction sets what action the watchdog takes when triggered. - WatchdogAction watchdog.Action + WatchdogAction watchdog.Action `flag:"watchdog-action"` // PanicSignal registers signal handling that panics. Usually set to // SIGUSR2(12) to troubleshoot hangs. -1 disables it. - PanicSignal int + PanicSignal int `flag:"panic-signal"` // ProfileEnable is set to prepare the sandbox to be profiled. - ProfileEnable bool + ProfileEnable bool `flag:"profile"` // RestoreFile is the path to the saved container image RestoreFile string @@ -233,104 +125,215 @@ type Config struct { // NumNetworkChannels controls the number of AF_PACKET sockets that map // to the same underlying network device. This allows netstack to better // scale for high throughput use cases. - NumNetworkChannels int + NumNetworkChannels int `flag:"num-network-channels"` // Rootless allows the sandbox to be started with a user that is not root. // Defense is depth measures are weaker with rootless. Specifically, the // sandbox and Gofer process run as root inside a user namespace with root // mapped to the caller's user. - Rootless bool + Rootless bool `flag:"rootless"` // AlsoLogToStderr allows to send log messages to stderr. - AlsoLogToStderr bool + AlsoLogToStderr bool `flag:"alsologtostderr"` // ReferenceLeakMode sets reference leak check mode - ReferenceLeakMode refs.LeakMode + ReferenceLeak refs.LeakMode `flag:"ref-leak-mode"` // OverlayfsStaleRead instructs the sandbox to assume that the root mount // is on a Linux overlayfs mount, which does not necessarily preserve // coherence between read-only and subsequent writable file descriptors // representing the "same" file. - OverlayfsStaleRead bool + OverlayfsStaleRead bool `flag:"overlayfs-stale-read"` + + // CPUNumFromQuota sets CPU number count to available CPU quota, using + // least integer value greater than or equal to quota. + // + // E.g. 0.2 CPU quota will result in 1, and 1.9 in 2. + CPUNumFromQuota bool `flag:"cpu-num-from-quota"` + + // Enables VFS2. + VFS2 bool `flag:"vfs2"` + + // Enables FUSE usage. + FUSE bool `flag:"fuse"` + + // Allows overriding of flags in OCI annotations. + AllowFlagOverride bool `flag:"allow-flag-override"` + + // Enables seccomp inside the sandbox. + OCISeccomp bool `flag:"oci-seccomp"` // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in // tests. It allows runsc to start the sandbox process as the current // user, and without chrooting the sandbox process. This can be // necessary in test environments that have limited capabilities. - TestOnlyAllowRunAsCurrentUserWithoutChroot bool + TestOnlyAllowRunAsCurrentUserWithoutChroot bool `flag:"TESTONLY-unsafe-nonroot"` // TestOnlyTestNameEnv should only be used in tests. It looks up for the // test name in the container environment variables and adds it to the debug // log file name. This is done to help identify the log with the test when // multiple tests are run in parallel, since there is no way to pass // parameters to the runtime from docker. - TestOnlyTestNameEnv string + TestOnlyTestNameEnv string `flag:"TESTONLY-test-name-env"` +} - // CPUNumFromQuota sets CPU number count to available CPU quota, using - // least integer value greater than or equal to quota. - // - // E.g. 0.2 CPU quota will result in 1, and 1.9 in 2. - CPUNumFromQuota bool +func (c *Config) validate() error { + if c.FileAccess == FileAccessShared && c.Overlay { + return fmt.Errorf("overlay flag is incompatible with shared file access") + } + if c.NumNetworkChannels <= 0 { + return fmt.Errorf("num_network_channels must be > 0, got: %d", c.NumNetworkChannels) + } + return nil +} + +// FileAccessType tells how the filesystem is accessed. +type FileAccessType int + +const ( + // FileAccessExclusive is the same as FileAccessShared, but enables + // extra caching for improved performance. It should only be used if + // the sandbox has exclusive access to the filesystem. + FileAccessExclusive FileAccessType = iota - // Enables VFS2 (not plumbled through yet). - VFS2 bool + // FileAccessShared sends IO requests to a Gofer process that validates the + // requests and forwards them to the host. + FileAccessShared +) - // Enables FUSE usage (not plumbled through yet). - FUSE bool +func fileAccessTypePtr(v FileAccessType) *FileAccessType { + return &v } -// ToFlags returns a slice of flags that correspond to the given Config. -func (c *Config) ToFlags() []string { - f := []string{ - "--root=" + c.RootDir, - "--debug=" + strconv.FormatBool(c.Debug), - "--log=" + c.LogFilename, - "--log-format=" + c.LogFormat, - "--debug-log=" + c.DebugLog, - "--panic-log=" + c.PanicLog, - "--debug-log-format=" + c.DebugLogFormat, - "--file-access=" + c.FileAccess.String(), - "--overlay=" + strconv.FormatBool(c.Overlay), - "--fsgofer-host-uds=" + strconv.FormatBool(c.FSGoferHostUDS), - "--network=" + c.Network.String(), - "--log-packets=" + strconv.FormatBool(c.LogPackets), - "--platform=" + c.Platform, - "--strace=" + strconv.FormatBool(c.Strace), - "--strace-syscalls=" + strings.Join(c.StraceSyscalls, ","), - "--strace-log-size=" + strconv.Itoa(int(c.StraceLogSize)), - "--watchdog-action=" + c.WatchdogAction.String(), - "--panic-signal=" + strconv.Itoa(c.PanicSignal), - "--profile=" + strconv.FormatBool(c.ProfileEnable), - "--net-raw=" + strconv.FormatBool(c.EnableRaw), - "--num-network-channels=" + strconv.Itoa(c.NumNetworkChannels), - "--rootless=" + strconv.FormatBool(c.Rootless), - "--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr), - "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode), - "--gso=" + strconv.FormatBool(c.HardwareGSO), - "--software-gso=" + strconv.FormatBool(c.SoftwareGSO), - "--rx-checksum-offload=" + strconv.FormatBool(c.RXChecksumOffload), - "--tx-checksum-offload=" + strconv.FormatBool(c.TXChecksumOffload), - "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead), - "--qdisc=" + c.QDisc.String(), +// Set implements flag.Value. +func (f *FileAccessType) Set(v string) error { + switch v { + case "shared": + *f = FileAccessShared + case "exclusive": + *f = FileAccessExclusive + default: + return fmt.Errorf("invalid file access type %q", v) } - if c.CPUNumFromQuota { - f = append(f, "--cpu-num-from-quota") + return nil +} + +// Get implements flag.Value. +func (f *FileAccessType) Get() interface{} { + return *f +} + +// String implements flag.Value. +func (f *FileAccessType) String() string { + switch *f { + case FileAccessShared: + return "shared" + case FileAccessExclusive: + return "exclusive" } - // Only include these if set since it is never to be used by users. - if c.TestOnlyAllowRunAsCurrentUserWithoutChroot { - f = append(f, "--TESTONLY-unsafe-nonroot=true") + panic(fmt.Sprintf("Invalid file access type %v", *f)) +} + +// NetworkType tells which network stack to use. +type NetworkType int + +const ( + // NetworkSandbox uses internal network stack, isolated from the host. + NetworkSandbox NetworkType = iota + + // NetworkHost redirects network related syscalls to the host network. + NetworkHost + + // NetworkNone sets up just loopback using netstack. + NetworkNone +) + +func networkTypePtr(v NetworkType) *NetworkType { + return &v +} + +// Set implements flag.Value. +func (n *NetworkType) Set(v string) error { + switch v { + case "sandbox": + *n = NetworkSandbox + case "host": + *n = NetworkHost + case "none": + *n = NetworkNone + default: + return fmt.Errorf("invalid network type %q", v) } - if len(c.TestOnlyTestNameEnv) != 0 { - f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv) + return nil +} + +// Get implements flag.Value. +func (n *NetworkType) Get() interface{} { + return *n +} + +// String implements flag.Value. +func (n *NetworkType) String() string { + switch *n { + case NetworkSandbox: + return "sandbox" + case NetworkHost: + return "host" + case NetworkNone: + return "none" } + panic(fmt.Sprintf("Invalid network type %v", *n)) +} + +// QueueingDiscipline is used to specify the kind of Queueing Discipline to +// apply for a give FDBasedLink. +type QueueingDiscipline int + +const ( + // QDiscNone disables any queueing for the underlying FD. + QDiscNone QueueingDiscipline = iota + + // QDiscFIFO applies a simple fifo based queue to the underlying FD. + QDiscFIFO +) + +func queueingDisciplinePtr(v QueueingDiscipline) *QueueingDiscipline { + return &v +} - if c.VFS2 { - f = append(f, "--vfs2=true") +// Set implements flag.Value. +func (q *QueueingDiscipline) Set(v string) error { + switch v { + case "none": + *q = QDiscNone + case "fifo": + *q = QDiscFIFO + default: + return fmt.Errorf("invalid qdisc %q", v) } + return nil +} + +// Get implements flag.Value. +func (q *QueueingDiscipline) Get() interface{} { + return *q +} - if c.FUSE { - f = append(f, "--fuse=true") +// String implements flag.Value. +func (q *QueueingDiscipline) String() string { + switch *q { + case QDiscNone: + return "none" + case QDiscFIFO: + return "fifo" } + panic(fmt.Sprintf("Invalid qdisc %v", *q)) +} + +func leakModePtr(v refs.LeakMode) *refs.LeakMode { + return &v +} - return f +func watchdogActionPtr(v watchdog.Action) *watchdog.Action { + return &v } diff --git a/runsc/config/config_test.go b/runsc/config/config_test.go new file mode 100644 index 000000000..fb162b7eb --- /dev/null +++ b/runsc/config/config_test.go @@ -0,0 +1,272 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "strings" + "testing" + + "gvisor.dev/gvisor/runsc/flag" +) + +func init() { + RegisterFlags() +} + +func TestDefault(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + // "--root" is always set to something different than the default. Reset it + // to make it easier to test that default values do not generate flags. + c.RootDir = "" + + // All defaults doesn't require setting flags. + flags := c.ToFlags() + if len(flags) > 0 { + t.Errorf("default flags not set correctly for: %s", flags) + } +} + +func setDefault(name string) { + fl := flag.CommandLine.Lookup(name) + fl.Value.Set(fl.DefValue) +} + +func TestFromFlags(t *testing.T) { + flag.CommandLine.Lookup("root").Value.Set("some-path") + flag.CommandLine.Lookup("debug").Value.Set("true") + flag.CommandLine.Lookup("num-network-channels").Value.Set("123") + flag.CommandLine.Lookup("network").Value.Set("none") + defer func() { + setDefault("root") + setDefault("debug") + setDefault("num-network-channels") + setDefault("network") + }() + + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + if want := "some-path"; c.RootDir != want { + t.Errorf("RootDir=%v, want: %v", c.RootDir, want) + } + if want := true; c.Debug != want { + t.Errorf("Debug=%v, want: %v", c.Debug, want) + } + if want := 123; c.NumNetworkChannels != want { + t.Errorf("NumNetworkChannels=%v, want: %v", c.NumNetworkChannels, want) + } + if want := NetworkNone; c.Network != want { + t.Errorf("Network=%v, want: %v", c.Network, want) + } +} + +func TestToFlags(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + c.RootDir = "some-path" + c.Debug = true + c.NumNetworkChannels = 123 + c.Network = NetworkNone + + flags := c.ToFlags() + if len(flags) != 4 { + t.Errorf("wrong number of flags set, want: 4, got: %d: %s", len(flags), flags) + } + t.Logf("Flags: %s", flags) + fm := map[string]string{} + for _, f := range flags { + kv := strings.Split(f, "=") + fm[kv[0]] = kv[1] + } + for name, want := range map[string]string{ + "--root": "some-path", + "--debug": "true", + "--num-network-channels": "123", + "--network": "none", + } { + if got, ok := fm[name]; ok { + if got != want { + t.Errorf("flag %q, want: %q, got: %q", name, want, got) + } + } else { + t.Errorf("flag %q not set", name) + } + } +} + +// TestInvalidFlags checks that enum flags fail when value is not in enum set. +func TestInvalidFlags(t *testing.T) { + for _, tc := range []struct { + name string + error string + }{ + { + name: "file-access", + error: "invalid file access type", + }, + { + name: "network", + error: "invalid network type", + }, + { + name: "qdisc", + error: "invalid qdisc", + }, + { + name: "watchdog-action", + error: "invalid watchdog action", + }, + { + name: "ref-leak-mode", + error: "invalid ref leak mode", + }, + } { + t.Run(tc.name, func(t *testing.T) { + defer setDefault(tc.name) + if err := flag.CommandLine.Lookup(tc.name).Value.Set("invalid"); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("flag.Value.Set(invalid) wrong error reported: %v", err) + } + }) + } +} + +func TestValidationFail(t *testing.T) { + for _, tc := range []struct { + name string + flags map[string]string + error string + }{ + { + name: "shared+overlay", + flags: map[string]string{ + "file-access": "shared", + "overlay": "true", + }, + error: "overlay flag is incompatible", + }, + { + name: "network-channels", + flags: map[string]string{ + "num-network-channels": "-1", + }, + error: "num_network_channels must be > 0", + }, + } { + t.Run(tc.name, func(t *testing.T) { + for name, val := range tc.flags { + defer setDefault(name) + if err := flag.CommandLine.Lookup(name).Value.Set(val); err != nil { + t.Errorf("%s=%q: %v", name, val, err) + } + } + if _, err := NewFromFlags(); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("NewFromFlags() wrong error reported: %v", err) + } + }) + } +} + +func TestOverride(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + c.AllowFlagOverride = true + + t.Run("string", func(t *testing.T) { + c.RootDir = "foobar" + if err := c.Override("root", "bar"); err != nil { + t.Fatalf("Override(root, bar) failed: %v", err) + } + defer setDefault("root") + if c.RootDir != "bar" { + t.Errorf("Override(root, bar) didn't work: %+v", c) + } + }) + + t.Run("bool", func(t *testing.T) { + c.Debug = true + if err := c.Override("debug", "false"); err != nil { + t.Fatalf("Override(debug, false) failed: %v", err) + } + defer setDefault("debug") + if c.Debug { + t.Errorf("Override(debug, false) didn't work: %+v", c) + } + }) + + t.Run("enum", func(t *testing.T) { + c.FileAccess = FileAccessShared + if err := c.Override("file-access", "exclusive"); err != nil { + t.Fatalf("Override(file-access, exclusive) failed: %v", err) + } + defer setDefault("file-access") + if c.FileAccess != FileAccessExclusive { + t.Errorf("Override(file-access, exclusive) didn't work: %+v", c) + } + }) +} + +func TestOverrideDisabled(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + const errMsg = "flag override disabled" + if err := c.Override("root", "path"); err == nil || !strings.Contains(err.Error(), errMsg) { + t.Errorf("Override() wrong error: %v", err) + } +} + +func TestOverrideError(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + c.AllowFlagOverride = true + for _, tc := range []struct { + name string + value string + error string + }{ + { + name: "invalid", + value: "valid", + error: `flag "invalid" not found`, + }, + { + name: "debug", + value: "invalid", + error: "error setting flag debug", + }, + { + name: "file-access", + value: "invalid", + error: "invalid file access type", + }, + } { + t.Run(tc.name, func(t *testing.T) { + if err := c.Override(tc.name, tc.value); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("Override(%q, %q) wrong error: %v", tc.name, tc.value, err) + } + }) + } +} diff --git a/runsc/config/flags.go b/runsc/config/flags.go new file mode 100644 index 000000000..a5f25cfa2 --- /dev/null +++ b/runsc/config/flags.go @@ -0,0 +1,205 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "strconv" + + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/watchdog" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/runsc/flag" +) + +var registration sync.Once + +// This is the set of flags used to populate Config. +func RegisterFlags() { + registration.Do(func() { + // Although these flags are not part of the OCI spec, they are used by + // Docker, and thus should not be changed. + flag.String("root", "", "root directory for storage of container state.") + flag.String("log", "", "file path where internal debug information is written, default is stdout.") + flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") + flag.Bool("debug", false, "enable debug logging.") + + // These flags are unique to runsc, and are used to configure parts of the + // system that are not covered by the runtime spec. + + // Debugging flags. + flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") + flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") + flag.Bool("log-packets", false, "enable network packet logging.") + flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") + flag.Bool("alsologtostderr", false, "send log messages to stderr.") + flag.Bool("allow-flag-override", false, "allow OCI annotations (dev.gvisor.flag.<name>) to override flags for debugging.") + + // Debugging flags: strace related + flag.Bool("strace", false, "enable strace.") + flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.") + flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.") + + // Flags that control sandbox runtime behavior. + flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") + flag.Var(watchdogActionPtr(watchdog.LogWarning), "watchdog-action", "sets what action the watchdog takes when triggered: log (default), panic.") + flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") + flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") + flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") + flag.Var(leakModePtr(refs.NoLeakChecking), "ref-leak-mode", "sets reference leak check mode: disabled (default), log-names, log-traces.") + flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") + flag.Bool("oci-seccomp", false, "Enables loading OCI seccomp filters inside the sandbox.") + + // Flags that control sandbox runtime behavior: FS related. + flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") + flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") + flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") + flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") + flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") + flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") + + // Flags that control sandbox runtime behavior: network related. + flag.Var(networkTypePtr(NetworkSandbox), "network", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") + flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") + flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") + flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.") + flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.") + flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.") + flag.Var(queueingDisciplinePtr(QDiscFIFO), "qdisc", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.") + flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") + + // Test flags, not to be used outside tests, ever. + flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") + flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") + }) +} + +// NewFromFlags creates a new Config with values coming from command line flags. +func NewFromFlags() (*Config, error) { + conf := &Config{} + + obj := reflect.ValueOf(conf).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + name, ok := f.Tag.Lookup("flag") + if !ok { + // No flag set for this field. + continue + } + fl := flag.CommandLine.Lookup(name) + if fl == nil { + panic(fmt.Sprintf("Flag %q not found", name)) + } + x := reflect.ValueOf(flag.Get(fl.Value)) + obj.Field(i).Set(x) + } + + if len(conf.RootDir) == 0 { + // If not set, set default root dir to something (hopefully) user-writeable. + conf.RootDir = "/var/run/runsc" + if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { + conf.RootDir = filepath.Join(runtimeDir, "runsc") + } + } + + if err := conf.validate(); err != nil { + return nil, err + } + return conf, nil +} + +// ToFlags returns a slice of flags that correspond to the given Config. +func (c *Config) ToFlags() []string { + var rv []string + + obj := reflect.ValueOf(c).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + name, ok := f.Tag.Lookup("flag") + if !ok { + // No flag set for this field. + continue + } + val := getVal(obj.Field(i)) + + flag := flag.CommandLine.Lookup(name) + if flag == nil { + panic(fmt.Sprintf("Flag %q not found", name)) + } + if val == flag.DefValue { + continue + } + rv = append(rv, fmt.Sprintf("--%s=%s", flag.Name, val)) + } + return rv +} + +// Override writes a new value to a flag. +func (c *Config) Override(name string, value string) error { + if !c.AllowFlagOverride { + return fmt.Errorf("flag override disabled, use --allow-flag-override to enable it") + } + + obj := reflect.ValueOf(c).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + fieldName, ok := f.Tag.Lookup("flag") + if !ok || fieldName != name { + // Not a flag field, or flag name doesn't match. + continue + } + fl := flag.CommandLine.Lookup(name) + if fl == nil { + // Flag must exist if there is a field match above. + panic(fmt.Sprintf("Flag %q not found", name)) + } + + // Use flag to convert the string value to the underlying flag type, using + // the same rules as the command-line for consistency. + if err := fl.Value.Set(value); err != nil { + return fmt.Errorf("error setting flag %s=%q: %w", name, value, err) + } + x := reflect.ValueOf(flag.Get(fl.Value)) + obj.Field(i).Set(x) + + // Validates the config again to ensure it's left in a consistent state. + return c.validate() + } + return fmt.Errorf("flag %q not found. Cannot set it to %q", name, value) +} + +func getVal(field reflect.Value) string { + if str, ok := field.Addr().Interface().(fmt.Stringer); ok { + return str.String() + } + switch field.Kind() { + case reflect.Bool: + return strconv.FormatBool(field.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(field.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return strconv.FormatUint(field.Uint(), 10) + case reflect.String: + return field.String() + default: + panic("unknown type " + field.Kind().String()) + } +} diff --git a/runsc/console/console.go b/runsc/console/console.go index 64b23639a..dbb88e117 100644 --- a/runsc/console/console.go +++ b/runsc/console/console.go @@ -24,11 +24,11 @@ import ( "golang.org/x/sys/unix" ) -// NewWithSocket creates pty master/slave pair, sends the master FD over the given -// socket, and returns the slave. +// NewWithSocket creates pty master/replica pair, sends the master FD over the given +// socket, and returns the replica. func NewWithSocket(socketPath string) (*os.File, error) { - // Create a new pty master and slave. - ptyMaster, ptySlave, err := pty.Open() + // Create a new pty master and replica. + ptyMaster, ptyReplica, err := pty.Open() if err != nil { return nil, fmt.Errorf("opening pty: %v", err) } @@ -37,18 +37,18 @@ func NewWithSocket(socketPath string) (*os.File, error) { // Get a connection to the socket path. conn, err := net.Dial("unix", socketPath) if err != nil { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("dialing socket %q: %v", socketPath, err) } defer conn.Close() uc, ok := conn.(*net.UnixConn) if !ok { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("connection is not a UnixConn: %T", conn) } socket, err := uc.File() if err != nil { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("getting file for unix socket %v: %v", uc, err) } defer socket.Close() @@ -56,8 +56,8 @@ func NewWithSocket(socketPath string) (*os.File, error) { // Send the master FD over the connection. msg := unix.UnixRights(int(ptyMaster.Fd())) if err := unix.Sendmsg(int(socket.Fd()), []byte("pty-master"), msg, nil, 0); err != nil { - ptySlave.Close() + ptyReplica.Close() return nil, fmt.Errorf("sending console over unix socket %q: %v", socketPath, err) } - return ptySlave, nil + return ptyReplica, nil } diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 9a9ee7e2a..c33755482 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -23,6 +23,7 @@ go_library( "//pkg/sync", "//runsc/boot", "//runsc/cgroup", + "//runsc/config", "//runsc/sandbox", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", @@ -65,6 +66,7 @@ go_test( "//pkg/urpc", "//runsc/boot", "//runsc/boot/platforms", + "//runsc/config", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_kr_pty//:go_default_library", diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 995d4e267..4228399b8 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -185,14 +185,14 @@ func TestJobControlSignalExec(t *testing.T) { t.Fatalf("error starting container: %v", err) } - // Create a pty master/slave. The slave will be passed to the exec + // Create a pty master/replica. The replica will be passed to the exec // process. - ptyMaster, ptySlave, err := pty.Open() + ptyMaster, ptyReplica, err := pty.Open() if err != nil { t.Fatalf("error opening pty: %v", err) } defer ptyMaster.Close() - defer ptySlave.Close() + defer ptyReplica.Close() // Exec bash and attach a terminal. Note that occasionally /bin/sh // may be a different shell or have a different configuration (such @@ -203,9 +203,9 @@ func TestJobControlSignalExec(t *testing.T) { // Don't let bash execute from profile or rc files, otherwise // our PID counts get messed up. Argv: []string{"/bin/bash", "--noprofile", "--norc"}, - // Pass the pty slave as FD 0, 1, and 2. + // Pass the pty replica as FD 0, 1, and 2. FilePayload: urpc.FilePayload{ - Files: []*os.File{ptySlave, ptySlave, ptySlave}, + Files: []*os.File{ptyReplica, ptyReplica, ptyReplica}, }, StdioIsPty: true, } diff --git a/runsc/container/container.go b/runsc/container/container.go index 7ad09bf23..63478ba8c 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -37,6 +37,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/sighandling" "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/cgroup" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/sandbox" "gvisor.dev/gvisor/runsc/specutils" ) @@ -269,7 +270,7 @@ type Args struct { // New creates the container in a new Sandbox process, unless the metadata // indicates that an existing Sandbox should be used. The caller must call // Destroy() on the container. -func New(conf *boot.Config, args Args) (*Container, error) { +func New(conf *config.Config, args Args) (*Container, error) { log.Debugf("Create container %q in root dir: %s", args.ID, conf.RootDir) if err := validateID(args.ID); err != nil { return nil, err @@ -397,7 +398,7 @@ func New(conf *boot.Config, args Args) (*Container, error) { } // Start starts running the containerized process inside the sandbox. -func (c *Container) Start(conf *boot.Config) error { +func (c *Container) Start(conf *config.Config) error { log.Debugf("Start container %q", c.ID) if err := c.Saver.lock(); err != nil { @@ -472,7 +473,7 @@ func (c *Container) Start(conf *boot.Config) error { // Restore takes a container and replaces its kernel and file system // to restore a container from its state file. -func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error { +func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile string) error { log.Debugf("Restore container %q", c.ID) if err := c.Saver.lock(); err != nil { return err @@ -499,7 +500,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile str } // Run is a helper that calls Create + Start + Wait. -func Run(conf *boot.Config, args Args) (syscall.WaitStatus, error) { +func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) { log.Debugf("Run container %q in root dir: %s", args.ID, conf.RootDir) c, err := New(conf, args) if err != nil { @@ -861,7 +862,7 @@ func (c *Container) waitForStopped() error { return backoff.Retry(op, b) } -func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bundleDir string, attached bool) ([]*os.File, *os.File, error) { +func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bundleDir string, attached bool) ([]*os.File, *os.File, error) { // Start with the general config flags. args := conf.ToFlags() @@ -901,9 +902,6 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *boot.Config, bund } args = append(args, "gofer", "--bundle", bundleDir) - if conf.Overlay { - args = append(args, "--panic-on-write=true") - } // Open the spec file to donate to the sandbox. specFile, err := specutils.OpenSpec(bundleDir) diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 5e8247bc8..ff0e60283 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -41,8 +41,9 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot/platforms" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -250,7 +251,7 @@ func readOutputNum(file string, position int) (int, error) { // run starts the sandbox and waits for it to exit, checking that the // application succeeded. -func run(spec *specs.Spec, conf *boot.Config) error { +func run(spec *specs.Spec, conf *config.Config) error { _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { return fmt.Errorf("error setting up container: %v", err) @@ -289,26 +290,24 @@ var ( ) // configs generates different configurations to run tests. -func configs(t *testing.T, opts ...configOption) map[string]*boot.Config { +func configs(t *testing.T, opts ...configOption) map[string]*config.Config { // Always load the default config. - cs := make(map[string]*boot.Config) + cs := make(map[string]*config.Config) + testutil.TestConfig(t) for _, o := range opts { + c := testutil.TestConfig(t) switch o { case overlay: - c := testutil.TestConfig(t) c.Overlay = true cs["overlay"] = c case ptrace: - c := testutil.TestConfig(t) c.Platform = platforms.Ptrace cs["ptrace"] = c case kvm: - c := testutil.TestConfig(t) c.Platform = platforms.KVM cs["kvm"] = c case nonExclusiveFS: - c := testutil.TestConfig(t) - c.FileAccess = boot.FileAccessShared + c.FileAccess = config.FileAccessShared cs["non-exclusive"] = c default: panic(fmt.Sprintf("unknown config option %v", o)) @@ -317,23 +316,13 @@ func configs(t *testing.T, opts ...configOption) map[string]*boot.Config { return cs } -func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*boot.Config { - vfs1 := configs(t, opts...) - - var optsVFS2 []configOption - for _, opt := range opts { - // TODO(gvisor.dev/issue/1487): Enable overlay tests. - if opt != overlay { - optsVFS2 = append(optsVFS2, opt) - } - } - - for key, value := range configs(t, optsVFS2...) { +func configsWithVFS2(t *testing.T, opts ...configOption) map[string]*config.Config { + all := configs(t, opts...) + for key, value := range configs(t, opts...) { value.VFS2 = true - vfs1[key+"VFS2"] = value + all[key+"VFS2"] = value } - - return vfs1 + return all } // TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle. @@ -512,7 +501,7 @@ func TestExePath(t *testing.T) { t.Fatalf("error making directory: %v", err) } - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { for _, test := range []struct { path string @@ -837,7 +826,7 @@ func TestExecProcList(t *testing.T) { // TestKillPid verifies that we can signal individual exec'd processes. func TestKillPid(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { app, err := testutil.FindFile("test/cmd/test_app/test_app") if err != nil { @@ -1470,7 +1459,7 @@ func TestRunNonRoot(t *testing.T) { // TestMountNewDir checks that runsc will create destination directory if it // doesn't exit. func TestMountNewDir(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { root, err := ioutil.TempDir(testutil.TmpDir(), "root") if err != nil { @@ -1490,6 +1479,8 @@ func TestMountNewDir(t *testing.T) { Source: srcDir, Type: "bind", }) + // Extra points for creating the mount with a readonly root. + spec.Root.Readonly = true if err := run(spec, conf); err != nil { t.Fatalf("error running sandbox: %v", err) @@ -1499,17 +1490,17 @@ func TestMountNewDir(t *testing.T) { } func TestReadonlyRoot(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { - spec := testutil.NewSpecWithArgs("/bin/touch", "/foo") + spec := testutil.NewSpecWithArgs("sleep", "100") spec.Root.Readonly = true + _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) } defer cleanup() - // Create, start and wait for the container. args := Args{ ID: testutil.RandomContainerID(), Spec: spec, @@ -1524,12 +1515,82 @@ func TestReadonlyRoot(t *testing.T) { t.Fatalf("error starting container: %v", err) } - ws, err := c.Wait() + // Read mounts to check that root is readonly. + out, ws, err := executeCombinedOutput(c, "/bin/sh", "-c", "mount | grep ' / '") + if err != nil || ws != 0 { + t.Fatalf("exec failed, ws: %v, err: %v", ws, err) + } + t.Logf("root mount: %q", out) + if !strings.Contains(string(out), "(ro)") { + t.Errorf("root not mounted readonly: %q", out) + } + + // Check that file cannot be created. + ws, err = execute(c, "/bin/touch", "/foo") if err != nil { - t.Fatalf("error waiting on container: %v", err) + t.Fatalf("touch file in ro mount: %v", err) } if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { - t.Fatalf("container failed, waitStatus: %v", ws) + t.Fatalf("wrong waitStatus: %v", ws) + } + }) + } +} + +func TestReadonlyMount(t *testing.T) { + for name, conf := range configsWithVFS2(t, all...) { + t.Run(name, func(t *testing.T) { + dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount") + if err != nil { + t.Fatalf("ioutil.TempDir() failed: %v", err) + } + spec := testutil.NewSpecWithArgs("sleep", "100") + spec.Mounts = append(spec.Mounts, specs.Mount{ + Destination: dir, + Source: dir, + Type: "bind", + Options: []string{"ro"}, + }) + spec.Root.Readonly = false + + _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) + if err != nil { + t.Fatalf("error setting up container: %v", err) + } + defer cleanup() + + args := Args{ + ID: testutil.RandomContainerID(), + Spec: spec, + BundleDir: bundleDir, + } + c, err := New(conf, args) + if err != nil { + t.Fatalf("error creating container: %v", err) + } + defer c.Destroy() + if err := c.Start(conf); err != nil { + t.Fatalf("error starting container: %v", err) + } + + // Read mounts to check that volume is readonly. + cmd := fmt.Sprintf("mount | grep ' %s '", dir) + out, ws, err := executeCombinedOutput(c, "/bin/sh", "-c", cmd) + if err != nil || ws != 0 { + t.Fatalf("exec failed, ws: %v, err: %v", ws, err) + } + t.Logf("mount: %q", out) + if !strings.Contains(string(out), "(ro)") { + t.Errorf("volume not mounted readonly: %q", out) + } + + // Check that file cannot be created. + ws, err = execute(c, "/bin/touch", path.Join(dir, "file")) + if err != nil { + t.Fatalf("touch file in ro mount: %v", err) + } + if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { + t.Fatalf("wrong WaitStatus: %v", ws) } }) } @@ -1616,54 +1677,6 @@ func TestUIDMap(t *testing.T) { } } -func TestReadonlyMount(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { - t.Run(name, func(t *testing.T) { - dir, err := ioutil.TempDir(testutil.TmpDir(), "ro-mount") - spec := testutil.NewSpecWithArgs("/bin/touch", path.Join(dir, "file")) - if err != nil { - t.Fatalf("ioutil.TempDir() failed: %v", err) - } - spec.Mounts = append(spec.Mounts, specs.Mount{ - Destination: dir, - Source: dir, - Type: "bind", - Options: []string{"ro"}, - }) - spec.Root.Readonly = false - - _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) - if err != nil { - t.Fatalf("error setting up container: %v", err) - } - defer cleanup() - - // Create, start and wait for the container. - args := Args{ - ID: testutil.RandomContainerID(), - Spec: spec, - BundleDir: bundleDir, - } - c, err := New(conf, args) - if err != nil { - t.Fatalf("error creating container: %v", err) - } - defer c.Destroy() - if err := c.Start(conf); err != nil { - t.Fatalf("error starting container: %v", err) - } - - ws, err := c.Wait() - if err != nil { - t.Fatalf("error waiting on container: %v", err) - } - if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { - t.Fatalf("container failed, waitStatus: %v", ws) - } - }) - } -} - // TestAbbreviatedIDs checks that runsc supports using abbreviated container // IDs in place of full IDs. func TestAbbreviatedIDs(t *testing.T) { @@ -2013,7 +2026,7 @@ func doDestroyStartingTest(t *testing.T, vfs2 bool) { } func TestCreateWorkingDir(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "cwd-create") if err != nil { @@ -2116,27 +2129,19 @@ func TestMountPropagation(t *testing.T) { // Check that mount didn't propagate to private mount. privFile := filepath.Join(priv, "mnt", "file") - execArgs := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "!", "-f", privFile}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { + if ws, err := execute(cont, "/usr/bin/test", "!", "-f", privFile); err != nil || ws != 0 { t.Fatalf("exec: test ! -f %q, ws: %v, err: %v", privFile, ws, err) } // Check that mount propagated to slave mount. slaveFile := filepath.Join(slave, "mnt", "file") - execArgs = &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", slaveFile}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { + if ws, err := execute(cont, "/usr/bin/test", "-f", slaveFile); err != nil || ws != 0 { t.Fatalf("exec: test -f %q, ws: %v, err: %v", privFile, ws, err) } } func TestMountSymlink(t *testing.T) { - for name, conf := range configsWithVFS2(t, overlay) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { dir, err := ioutil.TempDir(testutil.TmpDir(), "mount-symlink") if err != nil { @@ -2196,11 +2201,7 @@ func TestMountSymlink(t *testing.T) { // Check that symlink was resolved and mount was created where the symlink // is pointing to. file := path.Join(target, "file") - execArgs := &control.ExecArgs{ - Filename: "/usr/bin/test", - Argv: []string{"test", "-f", file}, - } - if ws, err := cont.executeSync(execArgs); err != nil || ws != 0 { + if ws, err := execute(cont, "/usr/bin/test", "-f", file); err != nil || ws != 0 { t.Fatalf("exec: test -f %q, ws: %v, err: %v", file, ws, err) } }) @@ -2326,6 +2327,35 @@ func TestTTYField(t *testing.T) { } } +func execute(cont *Container, name string, arg ...string) (syscall.WaitStatus, error) { + args := &control.ExecArgs{ + Filename: name, + Argv: append([]string{name}, arg...), + } + return cont.executeSync(args) +} + +func executeCombinedOutput(cont *Container, name string, arg ...string) ([]byte, syscall.WaitStatus, error) { + r, w, err := os.Pipe() + if err != nil { + return nil, 0, err + } + defer r.Close() + + args := &control.ExecArgs{ + Filename: name, + Argv: append([]string{name}, arg...), + FilePayload: urpc.FilePayload{Files: []*os.File{os.Stdin, w, w}}, + } + ws, err := cont.executeSync(args) + w.Close() + if err != nil { + return nil, 0, err + } + out, err := ioutil.ReadAll(r) + return out, ws, err +} + // executeSync synchronously executes a new process. func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { pid, err := cont.Execute(args) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index e189648f4..952215ec1 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/test/testutil" "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -60,7 +61,7 @@ func createSpecs(cmds ...[]string) ([]*specs.Spec, []string) { return specs, ids } -func startContainers(conf *boot.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { +func startContainers(conf *config.Config, specs []*specs.Spec, ids []string) ([]*Container, func(), error) { if len(conf.RootDir) == 0 { panic("conf.RootDir not set. Call testutil.SetupRootDir() to set.") } @@ -168,7 +169,7 @@ func TestMultiContainerSanity(t *testing.T) { // TestMultiPIDNS checks that it is possible to run 2 dead-simple // containers in the same sandbox with different pidns. func TestMultiPIDNS(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -213,7 +214,7 @@ func TestMultiPIDNS(t *testing.T) { // TestMultiPIDNSPath checks the pidns path. func TestMultiPIDNSPath(t *testing.T) { - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -579,7 +580,7 @@ func TestMultiContainerDestroy(t *testing.T) { t.Fatal("error finding test_app:", err) } - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1251,8 +1252,7 @@ func TestMultiContainerSharedMountReadonly(t *testing.T) { // Test that shared pod mounts continue to work after container is restarted. func TestMultiContainerSharedMountRestart(t *testing.T) { - //TODO(gvisor.dev/issue/1487): This is failing with VFS2. - for name, conf := range configs(t, all...) { + for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { rootDir, cleanup, err := testutil.SetupRootDir() if err != nil { @@ -1359,7 +1359,7 @@ func TestMultiContainerSharedMountRestart(t *testing.T) { } // Test that unsupported pod mounts options are ignored when matching master and -// slave mounts. +// replica mounts. func TestMultiContainerSharedMountUnsupportedOptions(t *testing.T) { for name, conf := range configsWithVFS2(t, all...) { t.Run(name, func(t *testing.T) { @@ -1517,8 +1517,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { } // Check that container isn't running anymore. - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err == nil { + if _, err := execute(c, "/bin/true"); err == nil { t.Fatalf("Container %q was not stopped after gofer death", c.ID) } @@ -1533,8 +1532,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { if err := waitForProcessList(c, pl); err != nil { t.Errorf("Container %q was affected by another container: %v", c.ID, err) } - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err != nil { + if _, err := execute(c, "/bin/true"); err != nil { t.Fatalf("Container %q was affected by another container: %v", c.ID, err) } } @@ -1556,8 +1554,7 @@ func TestMultiContainerGoferKilled(t *testing.T) { // Check that entire sandbox isn't running anymore. for _, c := range containers { - args := &control.ExecArgs{Argv: []string{"/bin/true"}} - if _, err := c.executeSync(args); err == nil { + if _, err := execute(c, "/bin/true"); err == nil { t.Fatalf("Container %q was not stopped after gofer death", c.ID) } } @@ -1719,12 +1716,11 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { homeDirs[name] = homeFile } - // We will sleep in the root container in order to ensure that - // the root container doesn't terminate before sub containers can be - // created. + // We will sleep in the root container in order to ensure that the root + //container doesn't terminate before sub containers can be created. rootCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s; sleep 1000", homeDirs["root"].Name())} subCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["sub"].Name())} - execCmd := []string{"/bin/sh", "-c", fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name())} + execCmd := fmt.Sprintf("printf \"$HOME\" > %s", homeDirs["exec"].Name()) // Setup the containers, a root container and sub container. specConfig, ids := createSpecs(rootCmd, subCmd) @@ -1735,9 +1731,8 @@ func TestMultiContainerHomeEnvDir(t *testing.T) { defer cleanup() // Exec into the root container synchronously. - args := &control.ExecArgs{Argv: execCmd} - if _, err := containers[0].executeSync(args); err != nil { - t.Errorf("error executing %+v: %v", args, err) + if _, err := execute(containers[0], "/bin/sh", "-c", execCmd); err != nil { + t.Errorf("error executing %+v: %v", execCmd, err) } // Wait for the subcontainer to finish. diff --git a/runsc/container/shared_volume_test.go b/runsc/container/shared_volume_test.go index bac177a88..cb5bffb89 100644 --- a/runsc/container/shared_volume_test.go +++ b/runsc/container/shared_volume_test.go @@ -25,14 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/test/testutil" - "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" ) // TestSharedVolume checks that modifications to a volume mount are propagated // into and out of the sandbox. func TestSharedVolume(t *testing.T) { conf := testutil.TestConfig(t) - conf.FileAccess = boot.FileAccessShared + conf.FileAccess = config.FileAccessShared // Main process just sleeps. We will use "exec" to probe the state of // the filesystem. @@ -168,11 +168,7 @@ func TestSharedVolume(t *testing.T) { func checkFile(c *Container, filename string, want []byte) error { cpy := filename + ".copy" - argsCp := &control.ExecArgs{ - Filename: "/bin/cp", - Argv: []string{"cp", "-f", filename, cpy}, - } - if _, err := c.executeSync(argsCp); err != nil { + if _, err := execute(c, "/bin/cp", "-f", filename, cpy); err != nil { return fmt.Errorf("unexpected error copying file %q to %q: %v", filename, cpy, err) } got, err := ioutil.ReadFile(cpy) @@ -189,7 +185,7 @@ func checkFile(c *Container, filename string, want []byte) error { // is reflected inside. func TestSharedVolumeFile(t *testing.T) { conf := testutil.TestConfig(t) - conf.FileAccess = boot.FileAccessShared + conf.FileAccess = config.FileAccessShared // Main process just sleeps. We will use "exec" to probe the state of // the filesystem. @@ -235,11 +231,7 @@ func TestSharedVolumeFile(t *testing.T) { } // Append to file inside the container and check that content is not lost. - argsAppend := &control.ExecArgs{ - Filename: "/bin/bash", - Argv: []string{"bash", "-c", "echo -n sandbox- >> " + filename}, - } - if _, err := c.executeSync(argsAppend); err != nil { + if _, err := execute(c, "/bin/bash", "-c", "echo -n sandbox- >> "+filename); err != nil { t.Fatalf("unexpected error appending file %q: %v", filename, err) } want = []byte("host-sandbox-") diff --git a/runsc/flag/flag.go b/runsc/flag/flag.go index 0ca4829d7..ba1ff833f 100644 --- a/runsc/flag/flag.go +++ b/runsc/flag/flag.go @@ -21,13 +21,19 @@ import ( type FlagSet = flag.FlagSet var ( - NewFlagSet = flag.NewFlagSet - String = flag.String Bool = flag.Bool - Int = flag.Int - Uint = flag.Uint CommandLine = flag.CommandLine + Int = flag.Int + NewFlagSet = flag.NewFlagSet Parse = flag.Parse + String = flag.String + Uint = flag.Uint + Var = flag.Var ) const ContinueOnError = flag.ContinueOnError + +// Get returns the flag's underlying object. +func Get(v flag.Value) interface{} { + return v.(flag.Getter).Get() +} diff --git a/runsc/fsgofer/BUILD b/runsc/fsgofer/BUILD index 05e3637f7..96c57a426 100644 --- a/runsc/fsgofer/BUILD +++ b/runsc/fsgofer/BUILD @@ -32,5 +32,6 @@ go_test( "//pkg/log", "//pkg/p9", "//pkg/test/testutil", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index 88814b83c..0cb9b1cae 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -29,7 +29,7 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_CLOCK_GETTIME: {}, syscall.SYS_CLONE: []seccomp.Rule{ { - seccomp.AllowValue( + seccomp.EqualTo( syscall.CLONE_VM | syscall.CLONE_FS | syscall.CLONE_FILES | @@ -43,46 +43,46 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_EPOLL_CTL: {}, syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, }, syscall.SYS_EVENTFD2: []seccomp.Rule{ { - seccomp.AllowValue(0), - seccomp.AllowValue(0), + seccomp.EqualTo(0), + seccomp.EqualTo(0), }, }, syscall.SYS_EXIT: {}, syscall.SYS_EXIT_GROUP: {}, syscall.SYS_FALLOCATE: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, }, syscall.SYS_FCHMOD: {}, syscall.SYS_FCHOWNAT: {}, syscall.SYS_FCNTL: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_SETFL), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_SETFL), }, { - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.F_GETFD), + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.F_GETFD), }, // Used by flipcall.PacketWindowAllocator.Init(). { - seccomp.AllowAny{}, - seccomp.AllowValue(unix.F_ADD_SEALS), + seccomp.MatchAny{}, + seccomp.EqualTo(unix.F_ADD_SEALS), }, }, syscall.SYS_FSTAT: {}, @@ -91,31 +91,31 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_FTRUNCATE: {}, syscall.SYS_FUTEX: { seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, // Non-private futex used for flipcall. seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAIT), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAIT), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, seccomp.Rule{ - seccomp.AllowAny{}, - seccomp.AllowValue(linux.FUTEX_WAKE), - seccomp.AllowAny{}, - seccomp.AllowAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(linux.FUTEX_WAKE), + seccomp.MatchAny{}, + seccomp.MatchAny{}, }, }, syscall.SYS_GETDENTS64: {}, @@ -137,28 +137,28 @@ var allowedSyscalls = seccomp.SyscallRules{ // TODO(b/148688965): Remove once this is gone from Go. syscall.SYS_MLOCK: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowValue(4096), + seccomp.MatchAny{}, + seccomp.EqualTo(4096), }, }, syscall.SYS_MMAP: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_SHARED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_SHARED), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), }, }, syscall.SYS_MPROTECT: {}, @@ -172,14 +172,14 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_READLINKAT: {}, syscall.SYS_RECVMSG: []seccomp.Rule{ { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), }, { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), }, }, syscall.SYS_RENAMEAT: {}, @@ -190,33 +190,33 @@ var allowedSyscalls = seccomp.SyscallRules{ syscall.SYS_SENDMSG: []seccomp.Rule{ // Used by fdchannel.Endpoint.SendFD(). { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(0), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(0), }, // Used by unet.SocketWriter.WriteVec(). { - seccomp.AllowAny{}, - seccomp.AllowAny{}, - seccomp.AllowValue(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), + seccomp.MatchAny{}, + seccomp.MatchAny{}, + seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), }, }, syscall.SYS_SHUTDOWN: []seccomp.Rule{ - {seccomp.AllowAny{}, seccomp.AllowValue(syscall.SHUT_RDWR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)}, }, syscall.SYS_SIGALTSTACK: {}, // Used by fdchannel.NewConnectedSockets(). syscall.SYS_SOCKETPAIR: { { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(0), }, }, syscall.SYS_SYMLINKAT: {}, syscall.SYS_TGKILL: []seccomp.Rule{ { - seccomp.AllowValue(uint64(os.Getpid())), + seccomp.EqualTo(uint64(os.Getpid())), }, }, syscall.SYS_UNLINKAT: {}, @@ -227,24 +227,24 @@ var allowedSyscalls = seccomp.SyscallRules{ var udsSyscalls = seccomp.SyscallRules{ syscall.SYS_SOCKET: []seccomp.Rule{ { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_STREAM), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_STREAM), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_DGRAM), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_DGRAM), + seccomp.EqualTo(0), }, { - seccomp.AllowValue(syscall.AF_UNIX), - seccomp.AllowValue(syscall.SOCK_SEQPACKET), - seccomp.AllowValue(0), + seccomp.EqualTo(syscall.AF_UNIX), + seccomp.EqualTo(syscall.SOCK_SEQPACKET), + seccomp.EqualTo(0), }, }, syscall.SYS_CONNECT: []seccomp.Rule{ { - seccomp.AllowAny{}, + seccomp.MatchAny{}, }, }, } diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go index a4b28cb8b..39f9851a8 100644 --- a/runsc/fsgofer/filter/config_amd64.go +++ b/runsc/fsgofer/filter/config_amd64.go @@ -25,8 +25,7 @@ import ( func init() { allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{ - {seccomp.AllowValue(linux.ARCH_GET_FS)}, - {seccomp.AllowValue(linux.ARCH_SET_FS)}, + {seccomp.EqualTo(linux.ARCH_SET_FS)}, } allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{} diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index c6694c278..0b628c8ce 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -29,7 +29,6 @@ import ( "path/filepath" "runtime" "strconv" - "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -45,7 +44,9 @@ const ( // modes to ensure an unopened/closed file fails all mode checks. invalidMode = p9.OpenFlags(math.MaxUint32) - openFlags = syscall.O_NOFOLLOW | syscall.O_CLOEXEC + openFlags = unix.O_NOFOLLOW | unix.O_CLOEXEC + + allowedOpenFlags = unix.O_TRUNC ) // Config sets configuration options for each attach point. @@ -123,7 +124,7 @@ func (a *attachPoint) Attach() (p9.File, error) { } // makeQID returns a unique QID for the given stat buffer. -func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { +func (a *attachPoint) makeQID(stat unix.Stat_t) p9.QID { a.deviceMu.Lock() defer a.deviceMu.Unlock() @@ -154,9 +155,7 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // localFile implements p9.File wrapping a local file. The underlying file // is opened during Walk() and stored in 'file' to be used with other // operations. The file is opened as readonly, unless it's a symlink or there is -// no read access, which requires O_PATH. 'file' is dup'ed when Walk(nil) is -// called to clone the file. This reduces the number of walks that need to be -// done by the host file system when files are reused. +// no read access, which requires O_PATH. // // The file may be reopened if the requested mode in Open() is not a subset of // current mode. Consequently, 'file' could have a mode wider than requested and @@ -168,11 +167,30 @@ func (a *attachPoint) makeQID(stat syscall.Stat_t) p9.QID { // performance with 'overlay2' storage driver. overlay2 eagerly copies the // entire file up when it's opened in write mode, and would perform badly when // multiple files are only being opened for read (esp. startup). +// +// File operations must use "at" functions whenever possible: +// * Local operations must use AT_EMPTY_PATH: +// fchownat(fd, "", AT_EMPTY_PATH, ...), instead of chown(fullpath, ...) +// * Creation operations must use (fd + name): +// mkdirat(fd, name, ...), instead of mkdir(fullpath, ...) +// +// Apart from being faster, it also adds another layer of defense against +// symlink attacks (note that O_NOFOLLOW applies only to the last element in +// the path). +// +// The few exceptions where it cannot be done are: utimensat on symlinks, and +// Connect() for the socket address. type localFile struct { + p9.DisallowClientCalls + // attachPoint is the attachPoint that serves this localFile. attachPoint *attachPoint - // hostPath will be safely updated by the Renamed hook. + // hostPath is the full path to the host file. It can be used for logging and + // the few cases where full path is required to operation the host file. In + // all other cases, use "file" directly. + // + // Note: it's safely updated by the Renamed hook. hostPath string // file is opened when localFile is created and it's never nil. It may be @@ -189,7 +207,7 @@ type localFile struct { mode p9.OpenFlags // fileType for this file. It is equivalent to: - // syscall.Stat_t.Mode & syscall.S_IFMT + // unix.Stat_t.Mode & unix.S_IFMT fileType uint32 qid p9.QID @@ -209,7 +227,7 @@ var procSelfFD *fd.FD // OpenProcSelfFD opens the /proc/self/fd directory, which will be used to // reopen file descriptors. func OpenProcSelfFD() error { - d, err := syscall.Open("/proc/self/fd", syscall.O_RDONLY|syscall.O_DIRECTORY, 0) + d, err := unix.Open("/proc/self/fd", unix.O_RDONLY|unix.O_DIRECTORY, 0) if err != nil { return fmt.Errorf("error opening /proc/self/fd: %v", err) } @@ -218,7 +236,7 @@ func OpenProcSelfFD() error { } func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) { - d, err := syscall.Openat(int(procSelfFD.FD()), strconv.Itoa(f.FD()), mode&^syscall.O_NOFOLLOW, 0) + d, err := unix.Openat(int(procSelfFD.FD()), strconv.Itoa(f.FD()), mode&^unix.O_NOFOLLOW, 0) if err != nil { return nil, err } @@ -227,17 +245,17 @@ func reopenProcFd(f *fd.FD, mode int) (*fd.FD, error) { } func openAnyFileFromParent(parent *localFile, name string) (*fd.FD, string, bool, error) { - path := path.Join(parent.hostPath, name) - f, readable, err := openAnyFile(path, func(mode int) (*fd.FD, error) { + pathDebug := path.Join(parent.hostPath, name) + f, readable, err := openAnyFile(pathDebug, func(mode int) (*fd.FD, error) { return fd.OpenAt(parent.file, name, openFlags|mode, 0) }) - return f, path, readable, err + return f, pathDebug, readable, err } -// openAnyFile attempts to open the file in O_RDONLY and if it fails fallsback +// openAnyFile attempts to open the file in O_RDONLY. If it fails, falls back // to O_PATH. 'path' is used for logging messages only. 'fn' is what does the // actual file open and is customizable by the caller. -func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, error) { +func openAnyFile(pathDebug string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, error) { // Attempt to open file in the following mode in order: // 1. RDONLY | NONBLOCK: for all files, directories, ro mounts, FIFOs. // Use non-blocking to prevent getting stuck inside open(2) for @@ -248,7 +266,7 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, readable bool }{ { - mode: syscall.O_RDONLY | syscall.O_NONBLOCK, + mode: unix.O_RDONLY | unix.O_NONBLOCK, readable: true, }, { @@ -266,36 +284,36 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, bool, return file, option.readable, nil } switch e := extractErrno(err); e { - case syscall.ENOENT: + case unix.ENOENT: // File doesn't exist, no point in retrying. return nil, false, e } // File failed to open. Try again with next mode, preserving 'err' in case // this was the last attempt. - log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|option.mode, path, err) + log.Debugf("Attempt %d to open file failed, mode: %#x, path: %q, err: %v", i, openFlags|option.mode, pathDebug, err) } // All attempts to open file have failed, return the last error. - log.Debugf("Failed to open file, path: %q, err: %v", path, err) + log.Debugf("Failed to open file, path: %q, err: %v", pathDebug, err) return nil, false, extractErrno(err) } -func checkSupportedFileType(stat syscall.Stat_t, permitSocket bool) error { - switch stat.Mode & syscall.S_IFMT { - case syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK: +func checkSupportedFileType(stat unix.Stat_t, permitSocket bool) error { + switch stat.Mode & unix.S_IFMT { + case unix.S_IFREG, unix.S_IFDIR, unix.S_IFLNK: return nil - case syscall.S_IFSOCK: + case unix.S_IFSOCK: if !permitSocket { - return syscall.EPERM + return unix.EPERM } return nil default: - return syscall.EPERM + return unix.EPERM } } -func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat syscall.Stat_t) (*localFile, error) { +func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat unix.Stat_t) (*localFile, error) { if err := checkSupportedFileType(stat, a.conf.HostUDS); err != nil { return nil, err } @@ -305,7 +323,7 @@ func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat hostPath: path, file: file, mode: invalidMode, - fileType: stat.Mode & syscall.S_IFMT, + fileType: stat.Mode & unix.S_IFMT, qid: a.makeQID(stat), controlReadable: readable, }, nil @@ -315,7 +333,7 @@ func newLocalFile(a *attachPoint, file *fd.FD, path string, readable bool, stat // non-blocking. If anything fails, returns nil. It's better to have a file // without host FD, than to fail the operation. func newFDMaybe(file *fd.FD) *fd.FD { - dupFD, err := syscall.Dup(file.FD()) + dupFD, err := unix.Dup(file.FD()) // Technically, the runtime may call the finalizer on file as soon as // FD() returns. runtime.KeepAlive(file) @@ -325,31 +343,23 @@ func newFDMaybe(file *fd.FD) *fd.FD { dup := fd.New(dupFD) // fd is blocking; non-blocking is required. - if err := syscall.SetNonblock(dup.FD(), true); err != nil { + if err := unix.SetNonblock(dup.FD(), true); err != nil { _ = dup.Close() return nil } return dup } -func fstat(fd int) (syscall.Stat_t, error) { - var stat syscall.Stat_t - if err := syscall.Fstat(fd, &stat); err != nil { - return syscall.Stat_t{}, err - } - return stat, nil -} - -func stat(path string) (syscall.Stat_t, error) { - var stat syscall.Stat_t - if err := syscall.Stat(path, &stat); err != nil { - return syscall.Stat_t{}, err +func fstat(fd int) (unix.Stat_t, error) { + var stat unix.Stat_t + if err := unix.Fstat(fd, &stat); err != nil { + return unix.Stat_t{}, err } return stat, nil } func fchown(fd int, uid p9.UID, gid p9.GID) error { - return syscall.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW) + return unix.Fchownat(fd, "", int(uid), int(gid), linux.AT_EMPTY_PATH|unix.AT_SYMLINK_NOFOLLOW) } // Open implements p9.File. @@ -357,10 +367,16 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { if l.isOpen() { panic(fmt.Sprintf("attempting to open already opened file: %q", l.hostPath)) } + mode := flags & p9.OpenFlagsModeMask + if mode == p9.WriteOnly || mode == p9.ReadWrite || flags&p9.OpenTruncate != 0 { + if err := l.checkROMount(); err != nil { + return nil, p9.QID{}, 0, err + } + } // Check if control file can be used or if a new open must be created. var newFile *fd.FD - if flags == p9.ReadOnly && l.controlReadable { + if mode == p9.ReadOnly && l.controlReadable && flags.OSFlags()&allowedOpenFlags == 0 { log.Debugf("Open reusing control file, flags: %v, %q", flags, l.hostPath) newFile = l.file } else { @@ -369,15 +385,15 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { // name_to_handle_at and open_by_handle_at aren't supported by overlay2. log.Debugf("Open reopening file, flags: %v, %q", flags, l.hostPath) var err error - // Constrain open flags to the open mode and O_TRUNC. - newFile, err = reopenProcFd(l.file, openFlags|(flags.OSFlags()&(syscall.O_ACCMODE|syscall.O_TRUNC))) + osFlags := flags.OSFlags() & (unix.O_ACCMODE | allowedOpenFlags) + newFile, err = reopenProcFd(l.file, openFlags|osFlags) if err != nil { return nil, p9.QID{}, 0, extractErrno(err) } } var fd *fd.FD - if l.fileType == syscall.S_IFREG { + if l.fileType == unix.S_IFREG { // Donate FD for regular files only. fd = newFDMaybe(newFile) } @@ -389,38 +405,38 @@ func (l *localFile) Open(flags p9.OpenFlags) (*fd.FD, p9.QID, uint32, error) { } l.file = newFile } - l.mode = flags & p9.OpenFlagsModeMask + l.mode = mode return fd, l.qid, 0, nil } // Create implements p9.File. -func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9.File, p9.QID, uint32, error) { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return nil, nil, p9.QID{}, 0, syscall.EBADF +func (l *localFile) Create(name string, p9Flags p9.OpenFlags, perm p9.FileMode, uid p9.UID, gid p9.GID) (*fd.FD, p9.File, p9.QID, uint32, error) { + if err := l.checkROMount(); err != nil { + return nil, nil, p9.QID{}, 0, err } + // Set file creation flags, plus allowed open flags from caller. + osFlags := openFlags | unix.O_CREAT | unix.O_EXCL + osFlags |= p9Flags.OSFlags() & allowedOpenFlags + // 'file' may be used for other operations (e.g. Walk), so read access is // always added to flags. Note that resulting file might have a wider mode // than needed for each particular case. - flags := openFlags | syscall.O_CREAT | syscall.O_EXCL + mode := p9Flags & p9.OpenFlagsModeMask if mode == p9.WriteOnly { - flags |= syscall.O_RDWR + osFlags |= unix.O_RDWR } else { - flags |= mode.OSFlags() + osFlags |= mode.OSFlags() } - child, err := fd.OpenAt(l.file, name, flags, uint32(perm.Permissions())) + child, err := fd.OpenAt(l.file, name, osFlags, uint32(perm.Permissions())) if err != nil { return nil, nil, p9.QID{}, 0, extractErrno(err) } cu := cleanup.Make(func() { _ = child.Close() // Best effort attempt to remove the file in case of failure. - if err := syscall.Unlinkat(l.file.FD(), name); err != nil { + if err := unix.Unlinkat(l.file.FD(), name, 0); err != nil { log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err) } }) @@ -439,7 +455,7 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid hostPath: path.Join(l.hostPath, name), file: child, mode: mode, - fileType: syscall.S_IFREG, + fileType: unix.S_IFREG, qid: l.attachPoint.makeQID(stat), } @@ -449,15 +465,11 @@ func (l *localFile) Create(name string, mode p9.OpenFlags, perm p9.FileMode, uid // Mkdir implements p9.File. func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) (p9.QID, error) { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return p9.QID{}, syscall.EBADF + if err := l.checkROMount(); err != nil { + return p9.QID{}, err } - if err := syscall.Mkdirat(l.file.FD(), name, uint32(perm.Permissions())); err != nil { + if err := unix.Mkdirat(l.file.FD(), name, uint32(perm.Permissions())); err != nil { return p9.QID{}, extractErrno(err) } cu := cleanup.Make(func() { @@ -469,7 +481,7 @@ func (l *localFile) Mkdir(name string, perm p9.FileMode, uid p9.UID, gid p9.GID) defer cu.Clean() // Open directory to change ownership and stat it. - flags := syscall.O_DIRECTORY | syscall.O_RDONLY | openFlags + flags := unix.O_DIRECTORY | unix.O_RDONLY | openFlags f, err := fd.OpenAt(l.file, name, flags, 0) if err != nil { return p9.QID{}, extractErrno(err) @@ -504,20 +516,20 @@ func (l *localFile) WalkGetAttr(names []string) ([]p9.QID, p9.File, p9.AttrMask, return qids, file, mask, attr, nil } -func (l *localFile) walk(names []string) ([]p9.QID, p9.File, syscall.Stat_t, error) { +func (l *localFile) walk(names []string) ([]p9.QID, p9.File, unix.Stat_t, error) { // Duplicate current file if 'names' is empty. if len(names) == 0 { newFile, readable, err := openAnyFile(l.hostPath, func(mode int) (*fd.FD, error) { return reopenProcFd(l.file, openFlags|mode) }) if err != nil { - return nil, nil, syscall.Stat_t{}, extractErrno(err) + return nil, nil, unix.Stat_t{}, extractErrno(err) } stat, err := fstat(newFile.FD()) if err != nil { _ = newFile.Close() - return nil, nil, syscall.Stat_t{}, extractErrno(err) + return nil, nil, unix.Stat_t{}, extractErrno(err) } c := &localFile{ @@ -533,7 +545,7 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, syscall.Stat_t, err } var qids []p9.QID - var lastStat syscall.Stat_t + var lastStat unix.Stat_t last := l for _, name := range names { f, path, readable, err := openAnyFileFromParent(last, name) @@ -541,17 +553,17 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, syscall.Stat_t, err _ = last.Close() } if err != nil { - return nil, nil, syscall.Stat_t{}, extractErrno(err) + return nil, nil, unix.Stat_t{}, extractErrno(err) } lastStat, err = fstat(f.FD()) if err != nil { _ = f.Close() - return nil, nil, syscall.Stat_t{}, extractErrno(err) + return nil, nil, unix.Stat_t{}, extractErrno(err) } c, err := newLocalFile(last.attachPoint, f, path, readable, lastStat) if err != nil { _ = f.Close() - return nil, nil, syscall.Stat_t{}, extractErrno(err) + return nil, nil, unix.Stat_t{}, extractErrno(err) } qids = append(qids, c.qid) @@ -562,8 +574,8 @@ func (l *localFile) walk(names []string) ([]p9.QID, p9.File, syscall.Stat_t, err // StatFS implements p9.File. func (l *localFile) StatFS() (p9.FSStat, error) { - var s syscall.Statfs_t - if err := syscall.Fstatfs(l.file.FD(), &s); err != nil { + var s unix.Statfs_t + if err := unix.Fstatfs(l.file.FD(), &s); err != nil { return p9.FSStat{}, extractErrno(err) } @@ -583,9 +595,9 @@ func (l *localFile) StatFS() (p9.FSStat, error) { // FSync implements p9.File. func (l *localFile) FSync() error { if !l.isOpen() { - return syscall.EBADF + return unix.EBADF } - if err := syscall.Fsync(l.file.FD()); err != nil { + if err := unix.Fsync(l.file.FD()); err != nil { return extractErrno(err) } return nil @@ -601,7 +613,7 @@ func (l *localFile) GetAttr(_ p9.AttrMask) (p9.QID, p9.AttrMask, p9.Attr, error) return l.qid, mask, attr, nil } -func (l *localFile) fillAttr(stat syscall.Stat_t) (p9.AttrMask, p9.Attr) { +func (l *localFile) fillAttr(stat unix.Stat_t) (p9.AttrMask, p9.Attr) { attr := p9.Attr{ Mode: p9.FileMode(stat.Mode), UID: p9.UID(stat.Uid), @@ -637,12 +649,8 @@ func (l *localFile) fillAttr(stat syscall.Stat_t) (p9.AttrMask, p9.Attr) { // cannot be changed atomically and user may see partial changes when // an error happens. func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } allowed := p9.SetAttrMask{ @@ -665,13 +673,13 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { // consistent result that is not attribute dependent. if !valid.IsSubsetOf(allowed) { log.Warningf("SetAttr() failed for %q, mask: %v", l.hostPath, valid) - return syscall.EPERM + return unix.EPERM } // Check if it's possible to use cached file, or if another one needs to be // opened for write. f := l.file - if l.fileType == syscall.S_IFREG && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { + if l.fileType == unix.S_IFREG && l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { var err error f, err = reopenProcFd(l.file, openFlags|os.O_WRONLY) if err != nil { @@ -692,21 +700,21 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { // over another. var err error if valid.Permissions { - if cerr := syscall.Fchmod(f.FD(), uint32(attr.Permissions)); cerr != nil { + if cerr := unix.Fchmod(f.FD(), uint32(attr.Permissions)); cerr != nil { log.Debugf("SetAttr fchmod failed %q, err: %v", l.hostPath, cerr) err = extractErrno(cerr) } } if valid.Size { - if terr := syscall.Ftruncate(f.FD(), int64(attr.Size)); terr != nil { + if terr := unix.Ftruncate(f.FD(), int64(attr.Size)); terr != nil { log.Debugf("SetAttr ftruncate failed %q, err: %v", l.hostPath, terr) err = extractErrno(terr) } } if valid.ATime || valid.MTime { - utimes := [2]syscall.Timespec{ + utimes := [2]unix.Timespec{ {Sec: 0, Nsec: linux.UTIME_OMIT}, {Sec: 0, Nsec: linux.UTIME_OMIT}, } @@ -727,15 +735,15 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } } - if l.fileType == syscall.S_IFLNK { + if l.fileType == unix.S_IFLNK { // utimensat operates different that other syscalls. To operate on a // symlink it *requires* AT_SYMLINK_NOFOLLOW with dirFD and a non-empty // name. - parent, err := syscall.Open(path.Dir(l.hostPath), openFlags|unix.O_PATH, 0) + parent, err := unix.Open(path.Dir(l.hostPath), openFlags|unix.O_PATH, 0) if err != nil { return extractErrno(err) } - defer syscall.Close(parent) + defer unix.Close(parent) if terr := utimensat(parent, path.Base(l.hostPath), utimes, linux.AT_SYMLINK_NOFOLLOW); terr != nil { log.Debugf("SetAttr utimens failed %q, err: %v", l.hostPath, terr) @@ -760,7 +768,7 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { if valid.GID { gid = int(attr.GID) } - if oerr := syscall.Fchownat(f.FD(), "", uid, gid, linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW); oerr != nil { + if oerr := unix.Fchownat(f.FD(), "", uid, gid, linux.AT_EMPTY_PATH|linux.AT_SYMLINK_NOFOLLOW); oerr != nil { log.Debugf("SetAttr fchownat failed %q, err: %v", l.hostPath, oerr) err = extractErrno(oerr) } @@ -770,28 +778,28 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { } func (*localFile) GetXattr(string, uint64) (string, error) { - return "", syscall.EOPNOTSUPP + return "", unix.EOPNOTSUPP } func (*localFile) SetXattr(string, string, uint32) error { - return syscall.EOPNOTSUPP + return unix.EOPNOTSUPP } func (*localFile) ListXattr(uint64) (map[string]struct{}, error) { - return nil, syscall.EOPNOTSUPP + return nil, unix.EOPNOTSUPP } func (*localFile) RemoveXattr(string) error { - return syscall.EOPNOTSUPP + return unix.EOPNOTSUPP } // Allocate implements p9.File. func (l *localFile) Allocate(mode p9.AllocateMode, offset, length uint64) error { if !l.isOpen() { - return syscall.EBADF + return unix.EBADF } - if err := syscall.Fallocate(l.file.FD(), mode.ToLinux(), int64(offset), int64(length)); err != nil { + if err := unix.Fallocate(l.file.FD(), mode.ToLinux(), int64(offset), int64(length)); err != nil { return extractErrno(err) } return nil @@ -804,12 +812,8 @@ func (*localFile) Rename(p9.File, string) error { // RenameAt implements p9.File.RenameAt. func (l *localFile) RenameAt(oldName string, directory p9.File, newName string) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } newParent := directory.(*localFile) @@ -822,10 +826,10 @@ func (l *localFile) RenameAt(oldName string, directory p9.File, newName string) // ReadAt implements p9.File. func (l *localFile) ReadAt(p []byte, offset uint64) (int, error) { if l.mode != p9.ReadOnly && l.mode != p9.ReadWrite { - return 0, syscall.EBADF + return 0, unix.EBADF } if !l.isOpen() { - return 0, syscall.EBADF + return 0, unix.EBADF } r, err := l.file.ReadAt(p, int64(offset)) @@ -840,10 +844,10 @@ func (l *localFile) ReadAt(p []byte, offset uint64) (int, error) { // WriteAt implements p9.File. func (l *localFile) WriteAt(p []byte, offset uint64) (int, error) { if l.mode != p9.WriteOnly && l.mode != p9.ReadWrite { - return 0, syscall.EBADF + return 0, unix.EBADF } if !l.isOpen() { - return 0, syscall.EBADF + return 0, unix.EBADF } w, err := l.file.WriteAt(p, int64(offset)) @@ -855,12 +859,8 @@ func (l *localFile) WriteAt(p []byte, offset uint64) (int, error) { // Symlink implements p9.File. func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9.QID, error) { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return p9.QID{}, syscall.EBADF + if err := l.checkROMount(); err != nil { + return p9.QID{}, err } if err := unix.Symlinkat(target, l.file.FD(), newName); err != nil { @@ -868,7 +868,7 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. } cu := cleanup.Make(func() { // Best effort attempt to remove the symlink in case of failure. - if err := syscall.Unlinkat(l.file.FD(), newName); err != nil { + if err := unix.Unlinkat(l.file.FD(), newName, 0); err != nil { log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, newName), err) } }) @@ -895,12 +895,8 @@ func (l *localFile) Symlink(target, newName string, uid p9.UID, gid p9.GID) (p9. // Link implements p9.File. func (l *localFile) Link(target p9.File, newName string) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } targetFile := target.(*localFile) @@ -911,49 +907,53 @@ func (l *localFile) Link(target p9.File, newName string) error { } // Mknod implements p9.File. -func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, _ p9.UID, _ p9.GID) (p9.QID, error) { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return p9.QID{}, syscall.EROFS - } - - hostPath := path.Join(l.hostPath, name) - - // Return EEXIST if the file already exists. - if _, err := stat(hostPath); err == nil { - return p9.QID{}, syscall.EEXIST +func (l *localFile) Mknod(name string, mode p9.FileMode, _ uint32, _ uint32, uid p9.UID, gid p9.GID) (p9.QID, error) { + if err := l.checkROMount(); err != nil { + return p9.QID{}, err } // From mknod(2) man page: // "EPERM: [...] if the filesystem containing pathname does not support // the type of node requested." if mode.FileType() != p9.ModeRegular { - return p9.QID{}, syscall.EPERM + return p9.QID{}, unix.EPERM } // Allow Mknod to create regular files. - if err := syscall.Mknod(hostPath, uint32(mode), 0); err != nil { + if err := unix.Mknodat(l.file.FD(), name, uint32(mode), 0); err != nil { return p9.QID{}, err } + cu := cleanup.Make(func() { + // Best effort attempt to remove the file in case of failure. + if err := unix.Unlinkat(l.file.FD(), name, 0); err != nil { + log.Warningf("error unlinking file %q after failure: %v", path.Join(l.hostPath, name), err) + } + }) + defer cu.Clean() - stat, err := stat(hostPath) + // Open file to change ownership and stat it. + child, err := fd.OpenAt(l.file, name, unix.O_PATH|openFlags, 0) if err != nil { return p9.QID{}, extractErrno(err) } + defer child.Close() + + if err := fchown(child.FD(), uid, gid); err != nil { + return p9.QID{}, extractErrno(err) + } + stat, err := fstat(child.FD()) + if err != nil { + return p9.QID{}, extractErrno(err) + } + + cu.Release() return l.attachPoint.makeQID(stat), nil } // UnlinkAt implements p9.File. func (l *localFile) UnlinkAt(name string, flags uint32) error { - conf := l.attachPoint.conf - if conf.ROMount { - if conf.PanicOnWrite { - panic("attempt to write to RO mount") - } - return syscall.EBADF + if err := l.checkROMount(); err != nil { + return err } if err := unix.Unlinkat(l.file.FD(), name, int(flags)); err != nil { @@ -965,10 +965,10 @@ func (l *localFile) UnlinkAt(name string, flags uint32) error { // Readdir implements p9.File. func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) { if l.mode != p9.ReadOnly && l.mode != p9.ReadWrite { - return nil, syscall.EBADF + return nil, unix.EBADF } if !l.isOpen() { - return nil, syscall.EBADF + return nil, unix.EBADF } // Readdirnames is a cursor over directories, so seek back to 0 to ensure it's @@ -985,7 +985,7 @@ func (l *localFile) Readdir(offset uint64, count uint32) ([]p9.Dirent, error) { // which causes the directory stream to resynchronize with the directory's // current contents). if l.lastDirentOffset != offset || offset == 0 { - if _, err := syscall.Seek(l.file.FD(), 0, 0); err != nil { + if _, err := unix.Seek(l.file.FD(), 0, 0); err != nil { return nil, extractErrno(err) } skip = offset @@ -1018,7 +1018,7 @@ func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) end := offset + uint64(count) for offset < end { - dirSize, err := syscall.ReadDirent(f, direntsBuf) + dirSize, err := unix.ReadDirent(f, direntsBuf) if err != nil { return dirents, err } @@ -1027,7 +1027,7 @@ func (l *localFile) readDirent(f int, offset uint64, count uint32, skip uint64) } names := names[:0] - _, _, names = syscall.ParseDirent(direntsBuf[:dirSize], -1, names) + _, _, names = unix.ParseDirent(direntsBuf[:dirSize], -1, names) // Skip over entries that the caller is not interested in. if skip > 0 { @@ -1072,7 +1072,7 @@ func (l *localFile) Readlink() (string, error) { return string(b[:n]), nil } } - return "", syscall.ENOMEM + return "", unix.ENOMEM } // Flush implements p9.File. @@ -1083,7 +1083,7 @@ func (l *localFile) Flush() error { // Connect implements p9.File. func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { if !l.attachPoint.conf.HostUDS { - return nil, syscall.ECONNREFUSED + return nil, unix.ECONNREFUSED } // TODO(gvisor.dev/issue/1003): Due to different app vs replacement @@ -1091,34 +1091,34 @@ func (l *localFile) Connect(flags p9.ConnectFlags) (*fd.FD, error) { // fit f.path in our sockaddr. We'd need to redirect through a shorter // path in order to actually connect to this socket. if len(l.hostPath) > linux.UnixPathMax { - return nil, syscall.ECONNREFUSED + return nil, unix.ECONNREFUSED } var stype int switch flags { case p9.StreamSocket: - stype = syscall.SOCK_STREAM + stype = unix.SOCK_STREAM case p9.DgramSocket: - stype = syscall.SOCK_DGRAM + stype = unix.SOCK_DGRAM case p9.SeqpacketSocket: - stype = syscall.SOCK_SEQPACKET + stype = unix.SOCK_SEQPACKET default: - return nil, syscall.ENXIO + return nil, unix.ENXIO } - f, err := syscall.Socket(syscall.AF_UNIX, stype, 0) + f, err := unix.Socket(unix.AF_UNIX, stype, 0) if err != nil { return nil, err } - if err := syscall.SetNonblock(f, true); err != nil { - _ = syscall.Close(f) + if err := unix.SetNonblock(f, true); err != nil { + _ = unix.Close(f) return nil, err } - sa := syscall.SockaddrUnix{Name: l.hostPath} - if err := syscall.Connect(f, &sa); err != nil { - _ = syscall.Close(f) + sa := unix.SockaddrUnix{Name: l.hostPath} + if err := unix.Connect(f, &sa); err != nil { + _ = unix.Close(f) return nil, err } @@ -1143,7 +1143,7 @@ func (l *localFile) Renamed(newDir p9.File, newName string) { } // extractErrno tries to determine the errno. -func extractErrno(err error) syscall.Errno { +func extractErrno(err error) unix.Errno { if err == nil { // This should never happen. The likely result will be that // some user gets the frustrating "error: SUCCESS" message. @@ -1153,18 +1153,18 @@ func extractErrno(err error) syscall.Errno { switch err { case os.ErrNotExist: - return syscall.ENOENT + return unix.ENOENT case os.ErrExist: - return syscall.EEXIST + return unix.EEXIST case os.ErrPermission: - return syscall.EACCES + return unix.EACCES case os.ErrInvalid: - return syscall.EINVAL + return unix.EINVAL } // See if it's an errno or a common wrapped error. switch e := err.(type) { - case syscall.Errno: + case unix.Errno: return e case *os.PathError: return extractErrno(e.Err) @@ -1176,5 +1176,12 @@ func extractErrno(err error) syscall.Errno { // Fall back to EIO. log.Debugf("Unknown error: %v, defaulting to EIO", err) - return syscall.EIO + return unix.EIO +} + +func (l *localFile) checkROMount() error { + if conf := l.attachPoint.conf; conf.ROMount { + return unix.EROFS + } + return nil } diff --git a/runsc/fsgofer/fsgofer_amd64_unsafe.go b/runsc/fsgofer/fsgofer_amd64_unsafe.go index 5d4aab597..c46958185 100644 --- a/runsc/fsgofer/fsgofer_amd64_unsafe.go +++ b/runsc/fsgofer/fsgofer_amd64_unsafe.go @@ -17,25 +17,25 @@ package fsgofer import ( - "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) -func statAt(dirFd int, name string) (syscall.Stat_t, error) { - nameBytes, err := syscall.BytePtrFromString(name) +func statAt(dirFd int, name string) (unix.Stat_t, error) { + nameBytes, err := unix.BytePtrFromString(name) if err != nil { - return syscall.Stat_t{}, err + return unix.Stat_t{}, err } namePtr := unsafe.Pointer(nameBytes) - var stat syscall.Stat_t + var stat unix.Stat_t statPtr := unsafe.Pointer(&stat) - if _, _, errno := syscall.Syscall6( - syscall.SYS_NEWFSTATAT, + if _, _, errno := unix.Syscall6( + unix.SYS_NEWFSTATAT, uintptr(dirFd), uintptr(namePtr), uintptr(statPtr), @@ -43,7 +43,7 @@ func statAt(dirFd int, name string) (syscall.Stat_t, error) { 0, 0); errno != 0 { - return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + return unix.Stat_t{}, syserr.FromHost(errno).ToError() } return stat, nil } diff --git a/runsc/fsgofer/fsgofer_arm64_unsafe.go b/runsc/fsgofer/fsgofer_arm64_unsafe.go index 8041fd352..491460718 100644 --- a/runsc/fsgofer/fsgofer_arm64_unsafe.go +++ b/runsc/fsgofer/fsgofer_arm64_unsafe.go @@ -17,25 +17,25 @@ package fsgofer import ( - "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/syserr" ) -func statAt(dirFd int, name string) (syscall.Stat_t, error) { - nameBytes, err := syscall.BytePtrFromString(name) +func statAt(dirFd int, name string) (unix.Stat_t, error) { + nameBytes, err := unix.BytePtrFromString(name) if err != nil { - return syscall.Stat_t{}, err + return unix.Stat_t{}, err } namePtr := unsafe.Pointer(nameBytes) - var stat syscall.Stat_t + var stat unix.Stat_t statPtr := unsafe.Pointer(&stat) - if _, _, errno := syscall.Syscall6( - syscall.SYS_FSTATAT, + if _, _, errno := unix.Syscall6( + unix.SYS_FSTATAT, uintptr(dirFd), uintptr(namePtr), uintptr(statPtr), @@ -43,7 +43,7 @@ func statAt(dirFd int, name string) (syscall.Stat_t, error) { 0, 0); errno != 0 { - return syscall.Stat_t{}, syserr.FromHost(errno).ToError() + return unix.Stat_t{}, syserr.FromHost(errno).ToError() } return stat, nil } diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index 94f167417..a84206686 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -21,9 +21,9 @@ import ( "os" "path" "path/filepath" - "syscall" "testing" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/test/testutil" @@ -32,7 +32,7 @@ import ( var allOpenFlags = []p9.OpenFlags{p9.ReadOnly, p9.WriteOnly, p9.ReadWrite} var ( - allTypes = []uint32{syscall.S_IFREG, syscall.S_IFDIR, syscall.S_IFLNK} + allTypes = []uint32{unix.S_IFREG, unix.S_IFDIR, unix.S_IFLNK} // allConfs is set in init(). allConfs []Config @@ -52,8 +52,8 @@ func init() { } } -func configTestName(config *Config) string { - if config.ROMount { +func configTestName(conf *Config) string { + if conf.ROMount { return "ROMount" } return "RWMount" @@ -83,7 +83,7 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { } want = append(want, b...) } else { - if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF { + if e, ok := err.(unix.Errno); !ok || e != unix.EBADF { return fmt.Errorf("WriteAt() should have failed, got: %d, want: EBADFD", err) } } @@ -101,7 +101,7 @@ func testReadWrite(f p9.File, flags p9.OpenFlags, content []byte) error { return fmt.Errorf("ReadAt() wrong data, got: %s, want: %s", string(rBuf), want) } } else { - if e, ok := err.(syscall.Errno); !ok || e != syscall.EBADF { + if e, ok := err.(unix.Errno); !ok || e != unix.EBADF { return fmt.Errorf("ReadAt() should have failed, got: %d, want: EBADFD", err) } } @@ -121,11 +121,11 @@ func (s state) String() string { func typeName(fileType uint32) string { switch fileType { - case syscall.S_IFREG: + case unix.S_IFREG: return "file" - case syscall.S_IFDIR: + case unix.S_IFDIR: return "directory" - case syscall.S_IFLNK: + case unix.S_IFLNK: return "symlink" default: panic(fmt.Sprintf("invalid file type for test: %d", fileType)) @@ -195,19 +195,19 @@ func setup(fileType uint32) (string, string, error) { var name string switch fileType { - case syscall.S_IFREG: + case unix.S_IFREG: name = "file" _, f, _, _, err := root.Create(name, p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { return "", "", fmt.Errorf("createFile(root, %q) failed, err: %v", "test", err) } defer f.Close() - case syscall.S_IFDIR: + case unix.S_IFDIR: name = "dir" if _, err := root.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.MkDir(%q) failed, err: %v", name, err) } - case syscall.S_IFLNK: + case unix.S_IFLNK: name = "symlink" if _, err := root.Symlink("/some/target", name, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { return "", "", fmt.Errorf("root.Symlink(%q) failed, err: %v", name, err) @@ -227,7 +227,7 @@ func createFile(dir *localFile, name string) (*localFile, error) { } func TestReadWrite(t *testing.T) { - runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -246,9 +246,13 @@ func TestReadWrite(t *testing.T) { if err != nil { t.Fatalf("%v: Walk(%s) failed, err: %v", s, "test", err) } - if _, _, _, err := l.Open(flags); err != nil { + fd, _, _, err := l.Open(flags) + if err != nil { t.Fatalf("%v: Open(%v) failed, err: %v", s, flags, err) } + if fd != nil { + defer fd.Close() + } if err := testReadWrite(l, flags, want); err != nil { t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err) } @@ -257,14 +261,14 @@ func TestReadWrite(t *testing.T) { } func TestCreate(t *testing.T) { - runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { for i, flags := range allOpenFlags { _, l, _, _, err := s.file.Create(fmt.Sprintf("test-%d", i), flags, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) if err != nil { t.Fatalf("%v, %v: WriteAt() failed, err: %v", s, flags, err) } - if err := testReadWrite(l, flags, []byte{}); err != nil { + if err := testReadWrite(l, flags, nil); err != nil { t.Fatalf("%v: testReadWrite(%v) failed: %v", s, flags, err) } } @@ -274,7 +278,7 @@ func TestCreate(t *testing.T) { // TestReadWriteDup tests that a file opened in any mode can be dup'ed and // reopened in any other mode. func TestReadWriteDup(t *testing.T) { - runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { child, err := createFile(s.file, "test") if err != nil { t.Fatalf("%v: createFile() failed, err: %v", s, err) @@ -304,9 +308,13 @@ func TestReadWriteDup(t *testing.T) { t.Fatalf("%v: Walk(<empty>) failed: %v", s, err) } defer dup.Close() - if _, _, _, err := dup.Open(dupFlags); err != nil { + fd, _, _, err := dup.Open(dupFlags) + if err != nil { t.Fatalf("%v: Open(%v) failed: %v", s, flags, err) } + if fd != nil { + defer fd.Close() + } if err := testReadWrite(dup, dupFlags, want); err != nil { t.Fatalf("%v: testReadWrite(%v) failed: %v", s, dupFlags, err) } @@ -316,19 +324,19 @@ func TestReadWriteDup(t *testing.T) { } func TestUnopened(t *testing.T) { - runCustom(t, []uint32{syscall.S_IFREG}, allConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFREG}, allConfs, func(t *testing.T, s state) { b := []byte("foobar") - if _, err := s.file.WriteAt(b, 0); err != syscall.EBADF { - t.Errorf("%v: WriteAt() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.WriteAt(b, 0); err != unix.EBADF { + t.Errorf("%v: WriteAt() should have failed, got: %v, expected: unix.EBADF", s, err) } - if _, err := s.file.ReadAt(b, 0); err != syscall.EBADF { - t.Errorf("%v: ReadAt() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.ReadAt(b, 0); err != unix.EBADF { + t.Errorf("%v: ReadAt() should have failed, got: %v, expected: unix.EBADF", s, err) } - if _, err := s.file.Readdir(0, 100); err != syscall.EBADF { - t.Errorf("%v: Readdir() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.Readdir(0, 100); err != unix.EBADF { + t.Errorf("%v: Readdir() should have failed, got: %v, expected: unix.EBADF", s, err) } - if err := s.file.FSync(); err != syscall.EBADF { - t.Errorf("%v: FSync() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.FSync(); err != unix.EBADF { + t.Errorf("%v: FSync() should have failed, got: %v, expected: unix.EBADF", s, err) } }) } @@ -338,7 +346,7 @@ func TestUnopened(t *testing.T) { // was open with O_PATH, but Open() was not checking for it and allowing the // control file to be reused. func TestOpenOPath(t *testing.T) { - runCustom(t, []uint32{syscall.S_IFREG}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFREG}, rwConfs, func(t *testing.T, s state) { // Fist remove all permissions on the file. if err := s.file.SetAttr(p9.SetAttrMask{Permissions: true}, p9.SetAttr{Permissions: p9.FileMode(0)}); err != nil { t.Fatalf("SetAttr(): %v", err) @@ -353,7 +361,7 @@ func TestOpenOPath(t *testing.T) { if newFile.(*localFile).controlReadable { t.Fatalf("control file didn't open with O_PATH: %+v", newFile) } - if _, _, _, err := newFile.Open(p9.ReadOnly); err != syscall.EACCES { + if _, _, _, err := newFile.Open(p9.ReadOnly); err != unix.EACCES { t.Fatalf("Open() should have failed, got: %v, wanted: EACCES", err) } }) @@ -375,7 +383,7 @@ func TestSetAttrPerm(t *testing.T) { valid := p9.SetAttrMask{Permissions: true} attr := p9.SetAttr{Permissions: 0777} got, err := SetGetAttr(s.file, valid, attr) - if s.fileType == syscall.S_IFLNK { + if s.fileType == unix.S_IFLNK { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -396,7 +404,7 @@ func TestSetAttrSize(t *testing.T) { valid := p9.SetAttrMask{Size: true} attr := p9.SetAttr{Size: size} got, err := SetGetAttr(s.file, valid, attr) - if s.fileType == syscall.S_IFLNK || s.fileType == syscall.S_IFDIR { + if s.fileType == unix.S_IFLNK || s.fileType == unix.S_IFDIR { if err == nil { t.Fatalf("%v: SetGetAttr(valid, %v) should have failed", s, attr.Permissions) } @@ -478,9 +486,9 @@ func TestLink(t *testing.T) { } err = dir.Link(s.file, linkFile) - if s.fileType == syscall.S_IFDIR { - if err != syscall.EPERM { - t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: syscall.EPERM", s, linkFile, err) + if s.fileType == unix.S_IFDIR { + if err != unix.EPERM { + t.Errorf("%v: Link(target, %s) should have failed, got: %v, expected: unix.EPERM", s, linkFile, err) } return } @@ -491,54 +499,64 @@ func TestLink(t *testing.T) { } func TestROMountChecks(t *testing.T) { + const want = unix.EROFS + uid := p9.UID(os.Getuid()) + gid := p9.GID(os.Getgid()) + runCustom(t, allTypes, roConfs, func(t *testing.T, s state) { - if _, _, _, _, err := s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: Create() should have failed, got: %v, expected: syscall.EBADF", s, err) + if s.fileType != unix.S_IFLNK { + if _, _, _, err := s.file.Open(p9.WriteOnly); err != want { + t.Errorf("Open() should have failed, got: %v, expected: %v", err, want) + } + if _, _, _, err := s.file.Open(p9.ReadWrite); err != want { + t.Errorf("Open() should have failed, got: %v, expected: %v", err, want) + } + if _, _, _, err := s.file.Open(p9.ReadOnly | p9.OpenTruncate); err != want { + t.Errorf("Open() should have failed, got: %v, expected: %v", err, want) + } + f, _, _, err := s.file.Open(p9.ReadOnly) + if err != nil { + t.Errorf("Open() failed: %v", err) + } + if f != nil { + _ = f.Close() + } } - if _, err := s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: MkDir() should have failed, got: %v, expected: syscall.EBADF", s, err) + + if _, _, _, _, err := s.file.Create("some_file", p9.ReadWrite, 0777, uid, gid); err != want { + t.Errorf("Create() should have failed, got: %v, expected: %v", err, want) } - if err := s.file.RenameAt("some_file", s.file, "other_file"); err != syscall.EBADF { - t.Errorf("%v: Rename() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.Mkdir("some_dir", 0777, uid, gid); err != want { + t.Errorf("MkDir() should have failed, got: %v, expected: %v", err, want) } - if _, err := s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != syscall.EBADF { - t.Errorf("%v: Symlink() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.RenameAt("some_file", s.file, "other_file"); err != want { + t.Errorf("Rename() should have failed, got: %v, expected: %v", err, want) } - if err := s.file.UnlinkAt("some_file", 0); err != syscall.EBADF { - t.Errorf("%v: UnlinkAt() should have failed, got: %v, expected: syscall.EBADF", s, err) + if _, err := s.file.Symlink("some_place", "some_symlink", uid, gid); err != want { + t.Errorf("Symlink() should have failed, got: %v, expected: %v", err, want) } - if err := s.file.Link(s.file, "some_link"); err != syscall.EBADF { - t.Errorf("%v: Link() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.UnlinkAt("some_file", 0); err != want { + t.Errorf("UnlinkAt() should have failed, got: %v, expected: %v", err, want) } - - valid := p9.SetAttrMask{Size: true} - attr := p9.SetAttr{Size: 0} - if err := s.file.SetAttr(valid, attr); err != syscall.EBADF { - t.Errorf("%v: SetAttr() should have failed, got: %v, expected: syscall.EBADF", s, err) + if err := s.file.Link(s.file, "some_link"); err != want { + t.Errorf("Link() should have failed, got: %v, expected: %v", err, want) + } + if _, err := s.file.Mknod("some-nod", 0777, 1, 2, uid, gid); err != want { + t.Errorf("Mknod() should have failed, got: %v, expected: %v", err, want) } - }) -} - -func TestROMountPanics(t *testing.T) { - conf := Config{ROMount: true, PanicOnWrite: true} - runCustom(t, allTypes, []Config{conf}, func(t *testing.T, s state) { - assertPanic(t, func() { s.file.Create("some_file", p9.ReadWrite, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.Mkdir("some_dir", 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.RenameAt("some_file", s.file, "other_file") }) - assertPanic(t, func() { s.file.Symlink("some_place", "some_symlink", p9.UID(os.Getuid()), p9.GID(os.Getgid())) }) - assertPanic(t, func() { s.file.UnlinkAt("some_file", 0) }) - assertPanic(t, func() { s.file.Link(s.file, "some_link") }) valid := p9.SetAttrMask{Size: true} attr := p9.SetAttr{Size: 0} - assertPanic(t, func() { s.file.SetAttr(valid, attr) }) + if err := s.file.SetAttr(valid, attr); err != want { + t.Errorf("SetAttr() should have failed, got: %v, expected: %v", err, want) + } }) } func TestWalkNotFound(t *testing.T) { - runCustom(t, []uint32{syscall.S_IFDIR}, allConfs, func(t *testing.T, s state) { - if _, _, err := s.file.Walk([]string{"nobody-here"}); err != syscall.ENOENT { - t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: syscall.ENOENT", s, "nobody-here", err) + runCustom(t, []uint32{unix.S_IFDIR}, allConfs, func(t *testing.T, s state) { + if _, _, err := s.file.Walk([]string{"nobody-here"}); err != unix.ENOENT { + t.Errorf("%v: Walk(%q) should have failed, got: %v, expected: unix.ENOENT", s, "nobody-here", err) } }) } @@ -557,7 +575,7 @@ func TestWalkDup(t *testing.T) { } func TestReaddir(t *testing.T) { - runCustom(t, []uint32{syscall.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { name := "dir" if _, err := s.file.Mkdir(name, 0777, p9.UID(os.Getuid()), p9.GID(os.Getgid())); err != nil { t.Fatalf("%v: MkDir(%s) failed, err: %v", s, name, err) @@ -682,7 +700,7 @@ func TestAttachInvalidType(t *testing.T) { defer os.RemoveAll(dir) fifo := filepath.Join(dir, "fifo") - if err := syscall.Mkfifo(fifo, 0755); err != nil { + if err := unix.Mkfifo(fifo, 0755); err != nil { t.Fatalf("Mkfifo(%q): %v", fifo, err) } @@ -741,3 +759,63 @@ func TestDoubleAttachError(t *testing.T) { t.Fatalf("Attach should have failed, got %v want non-nil", err) } } + +func TestTruncate(t *testing.T) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + child, err := createFile(s.file, "test") + if err != nil { + t.Fatalf("createFile() failed: %v", err) + } + defer child.Close() + want := []byte("foobar") + w, err := child.WriteAt(want, 0) + if err != nil { + t.Fatalf("Write() failed: %v", err) + } + if w != len(want) { + t.Fatalf("Write() was partial, got: %d, expected: %d", w, len(want)) + } + + _, l, err := s.file.Walk([]string{"test"}) + if err != nil { + t.Fatalf("Walk(%s) failed: %v", "test", err) + } + if _, _, _, err := l.Open(p9.ReadOnly | p9.OpenTruncate); err != nil { + t.Fatalf("Open() failed: %v", err) + } + _, mask, attr, err := l.GetAttr(p9.AttrMask{Size: true}) + if err != nil { + t.Fatalf("GetAttr() failed: %v", err) + } + if !mask.Size { + t.Fatalf("GetAttr() didn't return size: %+v", mask) + } + if attr.Size != 0 { + t.Fatalf("truncate didn't work, want: 0, got: %d", attr.Size) + } + }) +} + +func TestMknod(t *testing.T) { + runCustom(t, []uint32{unix.S_IFDIR}, rwConfs, func(t *testing.T, s state) { + _, err := s.file.Mknod("test", p9.ModeRegular|0777, 1, 2, p9.UID(os.Getuid()), p9.GID(os.Getgid())) + if err != nil { + t.Fatalf("Mknod() failed: %v", err) + } + + _, f, err := s.file.Walk([]string{"test"}) + if err != nil { + t.Fatalf("Walk() failed: %v", err) + } + fd, _, _, err := f.Open(p9.ReadWrite) + if err != nil { + t.Fatalf("Open() failed: %v", err) + } + if fd != nil { + defer fd.Close() + } + if err := testReadWrite(f, p9.ReadWrite, nil); err != nil { + t.Fatalf("testReadWrite() failed: %v", err) + } + }) +} diff --git a/runsc/fsgofer/fsgofer_unsafe.go b/runsc/fsgofer/fsgofer_unsafe.go index 542b54365..f11fea40d 100644 --- a/runsc/fsgofer/fsgofer_unsafe.go +++ b/runsc/fsgofer/fsgofer_unsafe.go @@ -15,18 +15,18 @@ package fsgofer import ( - "syscall" "unsafe" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/syserr" ) -func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) error { +func utimensat(dirFd int, name string, times [2]unix.Timespec, flags int) error { // utimensat(2) doesn't accept empty name, instead name must be nil to make it // operate directly on 'dirFd' unlike other *at syscalls. var namePtr unsafe.Pointer if name != "" { - nameBytes, err := syscall.BytePtrFromString(name) + nameBytes, err := unix.BytePtrFromString(name) if err != nil { return err } @@ -35,8 +35,8 @@ func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) err timesPtr := unsafe.Pointer(×[0]) - if _, _, errno := syscall.Syscall6( - syscall.SYS_UTIMENSAT, + if _, _, errno := unix.Syscall6( + unix.SYS_UTIMENSAT, uintptr(dirFd), uintptr(namePtr), uintptr(timesPtr), @@ -52,7 +52,7 @@ func utimensat(dirFd int, name string, times [2]syscall.Timespec, flags int) err func renameat(oldDirFD int, oldName string, newDirFD int, newName string) error { var oldNamePtr unsafe.Pointer if oldName != "" { - nameBytes, err := syscall.BytePtrFromString(oldName) + nameBytes, err := unix.BytePtrFromString(oldName) if err != nil { return err } @@ -60,15 +60,15 @@ func renameat(oldDirFD int, oldName string, newDirFD int, newName string) error } var newNamePtr unsafe.Pointer if newName != "" { - nameBytes, err := syscall.BytePtrFromString(newName) + nameBytes, err := unix.BytePtrFromString(newName) if err != nil { return err } newNamePtr = unsafe.Pointer(nameBytes) } - if _, _, errno := syscall.Syscall6( - syscall.SYS_RENAMEAT, + if _, _, errno := unix.Syscall6( + unix.SYS_RENAMEAT, uintptr(oldDirFD), uintptr(oldNamePtr), uintptr(newDirFD), diff --git a/runsc/main.go b/runsc/main.go index 69cb505fa..ed244c4ba 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -23,8 +23,6 @@ import ( "io/ioutil" "os" "os/signal" - "path/filepath" - "strings" "syscall" "time" @@ -32,8 +30,8 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/platform" - "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/cmd" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/flag" "gvisor.dev/gvisor/runsc/specutils" ) @@ -41,58 +39,17 @@ import ( var ( // Although these flags are not part of the OCI spec, they are used by // Docker, and thus should not be changed. - rootDir = flag.String("root", "", "root directory for storage of container state.") - logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout.") - logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") - debug = flag.Bool("debug", false, "enable debug logging.") - showVersion = flag.Bool("version", false, "show version and exit.") // TODO(gvisor.dev/issue/193): support systemd cgroups systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.") + showVersion = flag.Bool("version", false, "show version and exit.") // These flags are unique to runsc, and are used to configure parts of the // system that are not covered by the runtime spec. // Debugging flags. - debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") - panicLog = flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") - logPackets = flag.Bool("log-packets", false, "enable network packet logging.") - logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") - debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") - panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") - debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") - alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.") - - // Debugging flags: strace related - strace = flag.Bool("strace", false, "enable strace.") - straceSyscalls = flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.") - straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.") - - // Flags that control sandbox runtime behavior. - platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") - network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") - hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") - softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.") - txChecksumOffload = flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.") - rxChecksumOffload = flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.") - qDisc = flag.String("qdisc", "fifo", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.") - fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") - fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") - overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") - overlayfsStaleRead = flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") - watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.") - panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") - profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") - netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") - numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") - rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") - referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.") - cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") - vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") - fuseEnabled = flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") - - // Test flags, not to be used outside tests, ever. - testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") - testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") + logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") + debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") + panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") ) func main() { @@ -136,6 +93,8 @@ func main() { subcommands.Register(new(cmd.Gofer), internalGroup) subcommands.Register(new(cmd.Statefile), internalGroup) + config.RegisterFlags() + // All subcommands must be registered before flag parsing. flag.Parse() @@ -147,6 +106,12 @@ func main() { os.Exit(0) } + // Create a new Config from the flags. + conf, err := config.NewFromFlags() + if err != nil { + cmd.Fatalf(err.Error()) + } + // TODO(gvisor.dev/issue/193): support systemd cgroups if *systemdCgroup { fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193") @@ -157,103 +122,28 @@ func main() { if *logFD > -1 { errorLogger = os.NewFile(uintptr(*logFD), "error log file") - } else if *logFilename != "" { + } else if conf.LogFilename != "" { // We must set O_APPEND and not O_TRUNC because Docker passes // the same log file for all commands (and also parses these // log files), so we can't destroy them on each command. var err error - errorLogger, err = os.OpenFile(*logFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + errorLogger, err = os.OpenFile(conf.LogFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) if err != nil { - cmd.Fatalf("error opening log file %q: %v", *logFilename, err) + cmd.Fatalf("error opening log file %q: %v", conf.LogFilename, err) } } cmd.ErrorLogger = errorLogger - platformType := *platformName - if _, err := platform.Lookup(platformType); err != nil { - cmd.Fatalf("%v", err) - } - - fsAccess, err := boot.MakeFileAccessType(*fileAccess) - if err != nil { - cmd.Fatalf("%v", err) - } - - if fsAccess == boot.FileAccessShared && *overlay { - cmd.Fatalf("overlay flag is incompatible with shared file access") - } - - netType, err := boot.MakeNetworkType(*network) - if err != nil { + if _, err := platform.Lookup(conf.Platform); err != nil { cmd.Fatalf("%v", err) } - wa, err := boot.MakeWatchdogAction(*watchdogAction) - if err != nil { - cmd.Fatalf("%v", err) - } - - if *numNetworkChannels <= 0 { - cmd.Fatalf("num_network_channels must be > 0, got: %d", *numNetworkChannels) - } - - refsLeakMode, err := boot.MakeRefsLeakMode(*referenceLeakMode) - if err != nil { - cmd.Fatalf("%v", err) - } - - queueingDiscipline, err := boot.MakeQueueingDiscipline(*qDisc) - if err != nil { - cmd.Fatalf("%s", err) - } - // Sets the reference leak check mode. Also set it in config below to // propagate it to child processes. - refs.SetLeakMode(refsLeakMode) - - // Create a new Config from the flags. - conf := &boot.Config{ - RootDir: *rootDir, - Debug: *debug, - LogFilename: *logFilename, - LogFormat: *logFormat, - DebugLog: *debugLog, - PanicLog: *panicLog, - DebugLogFormat: *debugLogFormat, - FileAccess: fsAccess, - FSGoferHostUDS: *fsGoferHostUDS, - Overlay: *overlay, - Network: netType, - HardwareGSO: *hardwareGSO, - SoftwareGSO: *softwareGSO, - TXChecksumOffload: *txChecksumOffload, - RXChecksumOffload: *rxChecksumOffload, - LogPackets: *logPackets, - Platform: platformType, - Strace: *strace, - StraceLogSize: *straceLogSize, - WatchdogAction: wa, - PanicSignal: *panicSignal, - ProfileEnable: *profile, - EnableRaw: *netRaw, - NumNetworkChannels: *numNetworkChannels, - Rootless: *rootless, - AlsoLogToStderr: *alsoLogToStderr, - ReferenceLeakMode: refsLeakMode, - OverlayfsStaleRead: *overlayfsStaleRead, - CPUNumFromQuota: *cpuNumFromQuota, - VFS2: *vfs2Enabled, - FUSE: *fuseEnabled, - QDisc: queueingDiscipline, - TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, - TestOnlyTestNameEnv: *testOnlyTestNameEnv, - } - if len(*straceSyscalls) != 0 { - conf.StraceSyscalls = strings.Split(*straceSyscalls, ",") - } + refs.SetLeakMode(conf.ReferenceLeak) // Set up logging. - if *debug { + if conf.Debug { log.SetLevel(log.Debug) } @@ -275,14 +165,14 @@ func main() { if *debugLogFD > -1 { f := os.NewFile(uintptr(*debugLogFD), "debug log file") - e = newEmitter(*debugLogFormat, f) + e = newEmitter(conf.DebugLogFormat, f) - } else if *debugLog != "" { - f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */) + } else if conf.DebugLog != "" { + f, err := specutils.DebugLogFile(conf.DebugLog, subcommand, "" /* name */) if err != nil { - cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err) + cmd.Fatalf("error opening debug log file in %q: %v", conf.DebugLog, err) } - e = newEmitter(*debugLogFormat, f) + e = newEmitter(conf.DebugLogFormat, f) } else { // Stderr is reserved for the application, just discard the logs if no debug @@ -308,8 +198,8 @@ func main() { if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) } - } else if *alsoLogToStderr { - e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)} + } else if conf.AlsoLogToStderr { + e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)} } log.SetTarget(e) @@ -328,7 +218,7 @@ func main() { log.Infof("\t\tVFS2 enabled: %v", conf.VFS2) log.Infof("***************************") - if *testOnlyAllowRunAsCurrentUserWithoutChroot { + if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { // SIGTERM is sent to all processes if a test exceeds its // timeout and this case is handled by syscall_test_runner. log.Warningf("Block the TERM signal. This is only safe in tests!") @@ -364,11 +254,3 @@ func newEmitter(format string, logFile io.Writer) log.Emitter { cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format) panic("unreachable") } - -func init() { - // Set default root dir to something (hopefully) user-writeable. - *rootDir = "/var/run/runsc" - if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { - *rootDir = filepath.Join(runtimeDir, "runsc") - } -} diff --git a/runsc/sandbox/BUILD b/runsc/sandbox/BUILD index 2b9d4549d..f0a551a1e 100644 --- a/runsc/sandbox/BUILD +++ b/runsc/sandbox/BUILD @@ -26,6 +26,7 @@ go_library( "//runsc/boot", "//runsc/boot/platforms", "//runsc/cgroup", + "//runsc/config", "//runsc/console", "//runsc/specutils", "@com_github_cenkalti_backoff//:go_default_library", diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 817a923ad..0b9f39466 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/urpc" "gvisor.dev/gvisor/runsc/boot" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/specutils" ) @@ -49,26 +50,26 @@ import ( // // Run the following container to test it: // docker run -di --runtime=runsc -p 8080:80 -v $PWD:/usr/local/apache2/htdocs/ httpd:2.4 -func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *boot.Config) error { +func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *config.Config) error { log.Infof("Setting up network") switch conf.Network { - case boot.NetworkNone: + case config.NetworkNone: log.Infof("Network is disabled, create loopback interface only") if err := createDefaultLoopbackInterface(conn); err != nil { return fmt.Errorf("creating default loopback interface: %v", err) } - case boot.NetworkSandbox: + case config.NetworkSandbox: // Build the path to the net namespace of the sandbox process. // This is what we will copy. nsPath := filepath.Join("/proc", strconv.Itoa(pid), "ns/net") if err := createInterfacesAndRoutesFromNS(conn, nsPath, conf.HardwareGSO, conf.SoftwareGSO, conf.TXChecksumOffload, conf.RXChecksumOffload, conf.NumNetworkChannels, conf.QDisc); err != nil { return fmt.Errorf("creating interfaces from net namespace %q: %v", nsPath, err) } - case boot.NetworkHost: + case config.NetworkHost: // Nothing to do here. default: - return fmt.Errorf("invalid network type: %d", conf.Network) + return fmt.Errorf("invalid network type: %v", conf.Network) } return nil } @@ -115,7 +116,7 @@ func isRootNS() (bool, error) { // createInterfacesAndRoutesFromNS scrapes the interface and routes from the // net namespace with the given path, creates them in the sandbox, and removes // them from the host. -func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, txChecksumOffload bool, rxChecksumOffload bool, numNetworkChannels int, qDisc boot.QueueingDiscipline) error { +func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareGSO bool, softwareGSO bool, txChecksumOffload bool, rxChecksumOffload bool, numNetworkChannels int, qDisc config.QueueingDiscipline) error { // Join the network namespace that we will be copying. restore, err := joinNetNS(nsPath) if err != nil { diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 36bb0c9c9..a8f4f64a5 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -41,6 +41,7 @@ import ( "gvisor.dev/gvisor/runsc/boot" "gvisor.dev/gvisor/runsc/boot/platforms" "gvisor.dev/gvisor/runsc/cgroup" + "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/console" "gvisor.dev/gvisor/runsc/specutils" ) @@ -116,7 +117,7 @@ type Args struct { // New creates the sandbox process. The caller must call Destroy() on the // sandbox. -func New(conf *boot.Config, args *Args) (*Sandbox, error) { +func New(conf *config.Config, args *Args) (*Sandbox, error) { s := &Sandbox{ID: args.ID, Cgroup: args.Cgroup} // The Cleanup object cleans up partially created sandboxes when an error // occurs. Any errors occurring during cleanup itself are ignored. @@ -180,7 +181,7 @@ func (s *Sandbox) CreateContainer(cid string) error { } // StartRoot starts running the root container process inside the sandbox. -func (s *Sandbox) StartRoot(spec *specs.Spec, conf *boot.Config) error { +func (s *Sandbox) StartRoot(spec *specs.Spec, conf *config.Config) error { log.Debugf("Start root sandbox %q, PID: %d", s.ID, s.Pid) conn, err := s.sandboxConnect() if err != nil { @@ -203,7 +204,7 @@ func (s *Sandbox) StartRoot(spec *specs.Spec, conf *boot.Config) error { } // StartContainer starts running a non-root container inside the sandbox. -func (s *Sandbox) StartContainer(spec *specs.Spec, conf *boot.Config, cid string, goferFiles []*os.File) error { +func (s *Sandbox) StartContainer(spec *specs.Spec, conf *config.Config, cid string, goferFiles []*os.File) error { for _, f := range goferFiles { defer f.Close() } @@ -232,7 +233,7 @@ func (s *Sandbox) StartContainer(spec *specs.Spec, conf *boot.Config, cid string } // Restore sends the restore call for a container in the sandbox. -func (s *Sandbox) Restore(cid string, spec *specs.Spec, conf *boot.Config, filename string) error { +func (s *Sandbox) Restore(cid string, spec *specs.Spec, conf *config.Config, filename string) error { log.Debugf("Restore sandbox %q", s.ID) rf, err := os.Open(filename) @@ -344,7 +345,7 @@ func (s *Sandbox) connError(err error) error { // createSandboxProcess starts the sandbox as a subprocess by running the "boot" // command, passing in the bundle dir. -func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncFile *os.File) error { +func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyncFile *os.File) error { // nextFD is used to get unused FDs that we can pass to the sandbox. It // starts at 3 because 0, 1, and 2 are taken by stdin/out/err. nextFD := 3 @@ -477,10 +478,10 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF cmd.Stderr = nil // If the console control socket file is provided, then create a new - // pty master/slave pair and set the TTY on the sandbox process. + // pty master/replica pair and set the TTY on the sandbox process. if args.Spec.Process.Terminal && args.ConsoleSocket != "" { // console.NewWithSocket will send the master on the given - // socket, and return the slave. + // socket, and return the replica. tty, err := console.NewWithSocket(args.ConsoleSocket) if err != nil { return fmt.Errorf("setting up console with socket %q: %v", args.ConsoleSocket, err) @@ -555,10 +556,10 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF // Joins the network namespace if network is enabled. the sandbox talks // directly to the host network, which may have been configured in the // namespace. - if ns, ok := specutils.GetNS(specs.NetworkNamespace, args.Spec); ok && conf.Network != boot.NetworkNone { + if ns, ok := specutils.GetNS(specs.NetworkNamespace, args.Spec); ok && conf.Network != config.NetworkNone { log.Infof("Sandbox will be started in the container's network namespace: %+v", ns) nss = append(nss, ns) - } else if conf.Network == boot.NetworkHost { + } else if conf.Network == config.NetworkHost { log.Infof("Sandbox will be started in the host network namespace") } else { log.Infof("Sandbox will be started in new network namespace") @@ -568,7 +569,7 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF // User namespace depends on the network type. Host network requires to run // inside the user namespace specified in the spec or the current namespace // if none is configured. - if conf.Network == boot.NetworkHost { + if conf.Network == config.NetworkHost { if userns, ok := specutils.GetNS(specs.UserNamespace, args.Spec); ok { log.Infof("Sandbox will be started in container's user namespace: %+v", userns) nss = append(nss, userns) @@ -1179,7 +1180,7 @@ func deviceFileForPlatform(name string) (*os.File, error) { // checkBinaryPermissions verifies that the required binary bits are set on // the runsc executable. -func checkBinaryPermissions(conf *boot.Config) error { +func checkBinaryPermissions(conf *config.Config) error { // All platforms need the other exe bit neededBits := os.FileMode(0001) if conf.Platform == platforms.Ptrace { diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD index 43851a22f..679d8bc8e 100644 --- a/runsc/specutils/BUILD +++ b/runsc/specutils/BUILD @@ -16,6 +16,7 @@ go_library( "//pkg/bits", "//pkg/log", "//pkg/sentry/kernel/auth", + "//runsc/config", "@com_github_cenkalti_backoff//:go_default_library", "@com_github_mohae_deepcopy//:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", diff --git a/runsc/specutils/seccomp/BUILD b/runsc/specutils/seccomp/BUILD new file mode 100644 index 000000000..3520f2d6d --- /dev/null +++ b/runsc/specutils/seccomp/BUILD @@ -0,0 +1,34 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "seccomp", + srcs = [ + "audit_amd64.go", + "audit_arm64.go", + "seccomp.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/abi/linux", + "//pkg/bpf", + "//pkg/log", + "//pkg/seccomp", + "//pkg/sentry/kernel", + "//pkg/sentry/syscalls/linux", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) + +go_test( + name = "seccomp_test", + size = "small", + srcs = ["seccomp_test.go"], + library = ":seccomp", + deps = [ + "//pkg/binary", + "//pkg/bpf", + "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + ], +) diff --git a/tools/nogo/data/data.go b/runsc/specutils/seccomp/audit_amd64.go index eb84d0d27..417cf4a7a 100644 --- a/tools/nogo/data/data.go +++ b/runsc/specutils/seccomp/audit_amd64.go @@ -1,4 +1,4 @@ -// Copyright 2019 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,10 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package data contains shared data for nogo analysis. -// -// This is used to break a dependency cycle. -package data +// +build amd64 + +package seccomp + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" +) -// Objdump is the dumped binary under analysis. -var Objdump string +const ( + nativeArchAuditNo = linux.AUDIT_ARCH_X86_64 +) diff --git a/runsc/specutils/seccomp/audit_arm64.go b/runsc/specutils/seccomp/audit_arm64.go new file mode 100644 index 000000000..b727ceff2 --- /dev/null +++ b/runsc/specutils/seccomp/audit_arm64.go @@ -0,0 +1,25 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package seccomp + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" +) + +const ( + nativeArchAuditNo = linux.AUDIT_ARCH_AARCH64 +) diff --git a/runsc/specutils/seccomp/seccomp.go b/runsc/specutils/seccomp/seccomp.go new file mode 100644 index 000000000..5932f7a41 --- /dev/null +++ b/runsc/specutils/seccomp/seccomp.go @@ -0,0 +1,229 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package seccomp implements some features of libseccomp in order to support +// OCI. +package seccomp + +import ( + "fmt" + "syscall" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bpf" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/seccomp" + "gvisor.dev/gvisor/pkg/sentry/kernel" + slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" +) + +var ( + killThreadAction = linux.SECCOMP_RET_KILL_THREAD + trapAction = linux.SECCOMP_RET_TRAP + // runc always returns EPERM as the errorcode for SECCOMP_RET_ERRNO + errnoAction = linux.SECCOMP_RET_ERRNO.WithReturnCode(uint16(syscall.EPERM)) + // runc always returns EPERM as the errorcode for SECCOMP_RET_TRACE + traceAction = linux.SECCOMP_RET_TRACE.WithReturnCode(uint16(syscall.EPERM)) + allowAction = linux.SECCOMP_RET_ALLOW +) + +// BuildProgram generates a bpf program based on the given OCI seccomp +// config. +func BuildProgram(s *specs.LinuxSeccomp) (bpf.Program, error) { + defaultAction, err := convertAction(s.DefaultAction) + if err != nil { + return bpf.Program{}, fmt.Errorf("secomp default action: %w", err) + } + ruleset, err := convertRules(s) + if err != nil { + return bpf.Program{}, fmt.Errorf("invalid seccomp rules: %w", err) + } + + instrs, err := seccomp.BuildProgram(ruleset, defaultAction, killThreadAction) + if err != nil { + return bpf.Program{}, fmt.Errorf("building seccomp program: %w", err) + } + + program, err := bpf.Compile(instrs) + if err != nil { + return bpf.Program{}, fmt.Errorf("compiling seccomp program: %w", err) + } + + return program, nil +} + +// lookupSyscallNo gets the syscall number for the syscall with the given name +// for the given architecture. +func lookupSyscallNo(arch uint32, name string) (uint32, error) { + var table *kernel.SyscallTable + switch arch { + case linux.AUDIT_ARCH_X86_64: + table = slinux.AMD64 + case linux.AUDIT_ARCH_AARCH64: + table = slinux.ARM64 + } + if table == nil { + return 0, fmt.Errorf("unsupported architecture: %d", arch) + } + n, err := table.LookupNo(name) + if err != nil { + return 0, err + } + return uint32(n), nil +} + +// convertAction converts a LinuxSeccompAction to BPFAction +func convertAction(act specs.LinuxSeccompAction) (linux.BPFAction, error) { + // TODO(gvisor.dev/issue/3124): Update specs package to include ActLog and ActKillProcess. + switch act { + case specs.ActKill: + return killThreadAction, nil + case specs.ActTrap: + return trapAction, nil + case specs.ActErrno: + return errnoAction, nil + case specs.ActTrace: + return traceAction, nil + case specs.ActAllow: + return allowAction, nil + default: + return 0, fmt.Errorf("invalid action: %v", act) + } +} + +// convertRules converts OCI linux seccomp rules into RuleSets that can be used by +// the seccomp package to build a seccomp program. +func convertRules(s *specs.LinuxSeccomp) ([]seccomp.RuleSet, error) { + // NOTE: Architectures are only really relevant when calling 32bit syscalls + // on a 64bit system. Since we don't support that in gVisor anyway, we + // ignore Architectures and only test against the native architecture. + + ruleset := []seccomp.RuleSet{} + + for _, syscall := range s.Syscalls { + sysRules := seccomp.NewSyscallRules() + + action, err := convertAction(syscall.Action) + if err != nil { + return nil, err + } + + // Args + rules, err := convertArgs(syscall.Args) + if err != nil { + return nil, err + } + + for _, name := range syscall.Names { + syscallNo, err := lookupSyscallNo(nativeArchAuditNo, name) + if err != nil { + // If there is an error looking up the syscall number, assume it is + // not supported on this architecture and ignore it. This is, for + // better or worse, what runc does. + log.Warningf("OCI seccomp: ignoring syscall %q", name) + continue + } + + for _, rule := range rules { + sysRules.AddRule(uintptr(syscallNo), rule) + } + } + + ruleset = append(ruleset, seccomp.RuleSet{ + Rules: sysRules, + Action: action, + }) + } + + return ruleset, nil +} + +// convertArgs converts an OCI seccomp argument rule to a list of seccomp.Rule. +func convertArgs(args []specs.LinuxSeccompArg) ([]seccomp.Rule, error) { + argCounts := make([]uint, 6) + + for _, arg := range args { + if arg.Index > 6 { + return nil, fmt.Errorf("invalid index: %d", arg.Index) + } + + argCounts[arg.Index]++ + } + + // NOTE: If multiple rules apply to the same argument (same index) the + // action is triggered if any one of the rules matches (OR). If not, then + // all rules much match in order to trigger the action (AND). This appears to + // be some kind of legacy behavior of runc that nevertheless needs to be + // supported to maintain compatibility. + + hasMultipleArgs := false + for _, count := range argCounts { + if count > 1 { + hasMultipleArgs = true + break + } + } + + if hasMultipleArgs { + rules := []seccomp.Rule{} + + // Old runc behavior - do this for compatibility. + // Add rules as ORs by adding separate Rules. + for _, arg := range args { + rule := seccomp.Rule{nil, nil, nil, nil, nil, nil} + + if err := convertRule(arg, &rule); err != nil { + return nil, err + } + + rules = append(rules, rule) + } + + return rules, nil + } + + // Add rules as ANDs by adding to the same Rule. + rule := seccomp.Rule{nil, nil, nil, nil, nil, nil} + for _, arg := range args { + if err := convertRule(arg, &rule); err != nil { + return nil, err + } + } + + return []seccomp.Rule{rule}, nil +} + +// convertRule converts and adds the arg to a rule. +func convertRule(arg specs.LinuxSeccompArg, rule *seccomp.Rule) error { + switch arg.Op { + case specs.OpEqualTo: + rule[arg.Index] = seccomp.EqualTo(arg.Value) + case specs.OpNotEqual: + rule[arg.Index] = seccomp.NotEqual(arg.Value) + case specs.OpGreaterThan: + rule[arg.Index] = seccomp.GreaterThan(arg.Value) + case specs.OpGreaterEqual: + rule[arg.Index] = seccomp.GreaterThanOrEqual(arg.Value) + case specs.OpLessThan: + rule[arg.Index] = seccomp.LessThan(arg.Value) + case specs.OpLessEqual: + rule[arg.Index] = seccomp.LessThanOrEqual(arg.Value) + case specs.OpMaskedEqual: + rule[arg.Index] = seccomp.MaskedEqual(uintptr(arg.Value), uintptr(arg.ValueTwo)) + default: + return fmt.Errorf("unsupported operand: %q", arg.Op) + } + return nil +} diff --git a/runsc/specutils/seccomp/seccomp_test.go b/runsc/specutils/seccomp/seccomp_test.go new file mode 100644 index 000000000..2079cd2e9 --- /dev/null +++ b/runsc/specutils/seccomp/seccomp_test.go @@ -0,0 +1,414 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccomp + +import ( + "fmt" + "syscall" + "testing" + + specs "github.com/opencontainers/runtime-spec/specs-go" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bpf" +) + +type seccompData struct { + nr uint32 + arch uint32 + instructionPointer uint64 + args [6]uint64 +} + +// asInput converts a seccompData to a bpf.Input. +func asInput(d seccompData) bpf.Input { + return bpf.InputBytes{binary.Marshal(nil, binary.LittleEndian, d), binary.LittleEndian} +} + +// testInput creates an Input struct with given seccomp input values. +func testInput(arch uint32, syscallName string, args *[6]uint64) bpf.Input { + syscallNo, err := lookupSyscallNo(arch, syscallName) + if err != nil { + // Assume tests set valid syscall names. + panic(err) + } + + if args == nil { + argArray := [6]uint64{0, 0, 0, 0, 0, 0} + args = &argArray + } + + data := seccompData{ + nr: syscallNo, + arch: arch, + args: *args, + } + + return asInput(data) +} + +// testCase holds a seccomp test case. +type testCase struct { + name string + config specs.LinuxSeccomp + input bpf.Input + expected uint32 +} + +var ( + // seccompTests is a list of speccomp test cases. + seccompTests = []testCase{ + { + name: "default_allow", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + }, + input: testInput(nativeArchAuditNo, "read", nil), + expected: uint32(allowAction), + }, + { + name: "default_deny", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActErrno, + }, + input: testInput(nativeArchAuditNo, "read", nil), + expected: uint32(errnoAction), + }, + { + name: "deny_arch", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + }, + Action: specs.ActErrno, + }, + }, + }, + // Syscall matches but the arch is AUDIT_ARCH_X86 so the return + // value is the bad arch action. + input: asInput(seccompData{nr: 183, arch: 0x40000003}), // + expected: uint32(killThreadAction), + }, + { + name: "match_name_errno", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + "chmod", + }, + Action: specs.ActErrno, + }, + { + Names: []string{ + "write", + }, + Action: specs.ActTrace, + }, + }, + }, + input: testInput(nativeArchAuditNo, "getcwd", nil), + expected: uint32(errnoAction), + }, + { + name: "match_name_trace", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + "chmod", + }, + Action: specs.ActErrno, + }, + { + Names: []string{ + "write", + }, + Action: specs.ActTrace, + }, + }, + }, + input: testInput(nativeArchAuditNo, "write", nil), + expected: uint32(traceAction), + }, + { + name: "no_match_name_allow", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getcwd", + "chmod", + }, + Action: specs.ActErrno, + }, + { + Names: []string{ + "write", + }, + Action: specs.ActTrace, + }, + }, + }, + input: testInput(nativeArchAuditNo, "open", nil), + expected: uint32(allowAction), + }, + { + name: "simple_match_args", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + expected: uint32(errnoAction), + }, + { + name: "match_args_or", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + Op: specs.OpEqualTo, + }, + { + Index: 0, + Value: syscall.CLONE_VM, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + expected: uint32(errnoAction), + }, + { + name: "match_args_and", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getsockopt", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 1, + Value: syscall.SOL_SOCKET, + Op: specs.OpEqualTo, + }, + { + Index: 2, + Value: syscall.SO_PEERCRED, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET, syscall.SO_PEERCRED}), + expected: uint32(errnoAction), + }, + { + name: "no_match_args_and", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "getsockopt", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 1, + Value: syscall.SOL_SOCKET, + Op: specs.OpEqualTo, + }, + { + Index: 2, + Value: syscall.SO_PEERCRED, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET}), + expected: uint32(allowAction), + }, + { + name: "Simple args (no match)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + Op: specs.OpEqualTo, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_VM}), + expected: uint32(allowAction), + }, + { + name: "OpMaskedEqual (match)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS, + ValueTwo: syscall.CLONE_FS, + Op: specs.OpMaskedEqual, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS | syscall.CLONE_VM}), + expected: uint32(errnoAction), + }, + { + name: "OpMaskedEqual (no match)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActAllow, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: syscall.CLONE_FS | syscall.CLONE_VM, + ValueTwo: syscall.CLONE_FS | syscall.CLONE_VM, + Op: specs.OpMaskedEqual, + }, + }, + Action: specs.ActErrno, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + expected: uint32(allowAction), + }, + { + name: "OpMaskedEqual (clone)", + config: specs.LinuxSeccomp{ + DefaultAction: specs.ActErrno, + Syscalls: []specs.LinuxSyscall{ + { + Names: []string{ + "clone", + }, + // This comes from the Docker default seccomp + // profile for clone. + Args: []specs.LinuxSeccompArg{ + { + Index: 0, + Value: 0x7e020000, + ValueTwo: 0x0, + Op: specs.OpMaskedEqual, + }, + }, + Action: specs.ActAllow, + }, + }, + }, + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{0x50f00}), + expected: uint32(allowAction), + }, + } +) + +// TestRunscSeccomp generates seccomp programs from OCI config and executes +// them using runsc's library, comparing against expected results. +func TestRunscSeccomp(t *testing.T) { + for _, tc := range seccompTests { + t.Run(tc.name, func(t *testing.T) { + runscProgram, err := BuildProgram(&tc.config) + if err != nil { + t.Fatalf("generating runsc BPF: %v", err) + } + + if err := checkProgram(runscProgram, tc.input, tc.expected); err != nil { + t.Fatalf("running runsc BPF: %v", err) + } + }) + } +} + +// checkProgram runs the given program over the given input and checks the +// result against the expected output. +func checkProgram(p bpf.Program, in bpf.Input, expected uint32) error { + result, err := bpf.Exec(p, in) + if err != nil { + return err + } + + if result != expected { + // Include a decoded version of the program in output for debugging purposes. + decoded, _ := bpf.DecodeProgram(p) + return fmt.Errorf("Unexpected result: got: %d, expected: %d\nBPF Program\n%s", result, expected, decoded) + } + + return nil +} diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index 5015c3a84..0392e3e83 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -35,6 +35,7 @@ import ( "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/runsc/config" ) // ExePath must point to runsc binary, which is normally the same binary. It's @@ -110,11 +111,6 @@ func ValidateSpec(spec *specs.Spec) error { log.Warningf("noNewPrivileges ignored. PR_SET_NO_NEW_PRIVS is assumed to always be set.") } - // TODO(gvisor.dev/issue/510): Apply seccomp to application inside sandbox. - if spec.Linux != nil && spec.Linux.Seccomp != nil { - log.Warningf("Seccomp spec is being ignored") - } - if spec.Linux != nil && spec.Linux.RootfsPropagation != "" { if err := validateRootfsPropagation(spec.Linux.RootfsPropagation); err != nil { return err @@ -161,18 +157,18 @@ func OpenSpec(bundleDir string) (*os.File, error) { // ReadSpec reads an OCI runtime spec from the given bundle directory. // ReadSpec also normalizes all potential relative paths into absolute // path, e.g. spec.Root.Path, mount.Source. -func ReadSpec(bundleDir string) (*specs.Spec, error) { +func ReadSpec(bundleDir string, conf *config.Config) (*specs.Spec, error) { specFile, err := OpenSpec(bundleDir) if err != nil { return nil, fmt.Errorf("error opening spec file %q: %v", filepath.Join(bundleDir, "config.json"), err) } defer specFile.Close() - return ReadSpecFromFile(bundleDir, specFile) + return ReadSpecFromFile(bundleDir, specFile, conf) } // ReadSpecFromFile reads an OCI runtime spec from the given File, and // normalizes all relative paths into absolute by prepending the bundle dir. -func ReadSpecFromFile(bundleDir string, specFile *os.File) (*specs.Spec, error) { +func ReadSpecFromFile(bundleDir string, specFile *os.File, conf *config.Config) (*specs.Spec, error) { if _, err := specFile.Seek(0, os.SEEK_SET); err != nil { return nil, fmt.Errorf("error seeking to beginning of file %q: %v", specFile.Name(), err) } @@ -195,6 +191,20 @@ func ReadSpecFromFile(bundleDir string, specFile *os.File) (*specs.Spec, error) m.Source = absPath(bundleDir, m.Source) } } + + // Override flags using annotation to allow customization per sandbox + // instance. + for annotation, val := range spec.Annotations { + const flagPrefix = "dev.gvisor.flag." + if strings.HasPrefix(annotation, flagPrefix) { + name := annotation[len(flagPrefix):] + log.Infof("Overriding flag: %s=%q", name, val) + if err := conf.Override(name, val); err != nil { + return nil, err + } + } + } + return &spec, nil } diff --git a/scripts/packetdrill_tests.sh b/scripts/packetdrill_tests.sh index 1a8181ac8..cdb98c834 100755 --- a/scripts/packetdrill_tests.sh +++ b/scripts/packetdrill_tests.sh @@ -16,8 +16,8 @@ source $(dirname $0)/common.sh -make load-packetdrill +QUERY_RESULT=$(query 'attr(tags, manual, tests(//test/packetdrill/...))') install_runsc_for_test runsc-d -QUERY_RESULT=$(query "attr(tags, manual, tests(//test/packetdrill/...))") +make load-packetdrill test_runsc $QUERY_RESULT diff --git a/scripts/packetimpact_tests.sh b/scripts/packetimpact_tests.sh index 77fb84bc3..4878b72f4 100755 --- a/scripts/packetimpact_tests.sh +++ b/scripts/packetimpact_tests.sh @@ -16,8 +16,9 @@ source $(dirname $0)/common.sh -make load-packetimpact + +QUERY_RESULT=$(query 'attr(tags, packetimpact, tests(//test/packetimpact/...))') install_runsc_for_test runsc-d -QUERY_RESULT=$(query "attr(tags, packetimpact, tests(//test/packetimpact/...))") +make load-packetimpact test_runsc $QUERY_RESULT diff --git a/shim/BUILD b/shim/BUILD index e581618b2..8d29c459b 100644 --- a/shim/BUILD +++ b/shim/BUILD @@ -10,6 +10,6 @@ pkg_tar( mode = "0644", package_dir = "/etc/containerd", visibility = [ - "//runsc:__pkg__", + "//visibility:public", ], ) diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go index 394fce820..6671a4969 100644 --- a/test/benchmarks/database/redis_test.go +++ b/test/benchmarks/database/redis_test.go @@ -84,12 +84,12 @@ func BenchmarkRedis(b *testing.B) { ip, err := serverMachine.IPAddress() if err != nil { - b.Fatal("failed to get IP from server: %v", err) + b.Fatalf("failed to get IP from server: %v", err) } serverPort, err := server.FindPort(ctx, port) if err != nil { - b.Fatal("failed to get IP from server: %v", err) + b.Fatalf("failed to get IP from server: %v", err) } if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil { diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go index f4236ba37..ef1b8e4ea 100644 --- a/test/benchmarks/fs/bazel_test.go +++ b/test/benchmarks/fs/bazel_test.go @@ -62,7 +62,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { container := machine.GetContainer(ctx, b) defer container.CleanUp(ctx) - // Start a container and sleep by an order of b.N. + // Start a container and sleep. if err := container.Spawn(ctx, dockerutil.RunOpts{ Image: image, }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { @@ -70,12 +70,13 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { } // If we are running on a tmpfs, copy to /tmp which is a tmpfs. + prefix := "" if bm.tmpfs { if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, "cp", "-r", workdir, "/tmp/."); err != nil { - b.Fatal("failed to copy directory: %v %s", err, out) + b.Fatalf("failed to copy directory: %v (%s)", err, out) } - workdir = "/tmp" + workdir + prefix = "/tmp" } // Restart profiles after the copy. @@ -94,7 +95,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { b.StartTimer() got, err := container.Exec(ctx, dockerutil.ExecOpts{ - WorkDir: workdir, + WorkDir: prefix + workdir, }, "bazel", "build", "-c", "opt", target) if err != nil { b.Fatalf("build failed with: %v", err) @@ -107,7 +108,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { } // Clean bazel in case we use b.N. _, err = container.Exec(ctx, dockerutil.ExecOpts{ - WorkDir: workdir, + WorkDir: prefix + workdir, }, "bazel", "clean") if err != nil { b.Fatalf("build failed with: %v", err) diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD index bd3f6245c..472b5c387 100644 --- a/test/benchmarks/network/BUILD +++ b/test/benchmarks/network/BUILD @@ -5,8 +5,15 @@ package(licenses = ["notice"]) go_library( name = "network", testonly = 1, - srcs = ["network.go"], - deps = ["//test/benchmarks/harness"], + srcs = [ + "network.go", + "static_server.go", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], ) go_test( diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go index 336e04c91..369ab326e 100644 --- a/test/benchmarks/network/httpd_test.go +++ b/test/benchmarks/network/httpd_test.go @@ -14,22 +14,19 @@ package network import ( - "context" "fmt" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" ) // see Dockerfile '//images/benchmarks/httpd'. -var docs = map[string]string{ +var httpdDocs = map[string]string{ "notfound": "notfound", "1Kb": "latin1k.txt", "10Kb": "latin10k.txt", "100Kb": "latin100k.txt", - "1000Kb": "latin1000k.txt", "1Mb": "latin1024k.txt", "10Mb": "latin10240k.txt", } @@ -37,30 +34,17 @@ var docs = map[string]string{ // BenchmarkHttpdConcurrency iterates the concurrency argument and tests // how well the runtime under test handles requests in parallel. func BenchmarkHttpdConcurrency(b *testing.B) { - // Grab a machine for the client and server. - clientMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get client: %v", err) - } - defer clientMachine.CleanUp() - - serverMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get server: %v", err) - } - defer serverMachine.CleanUp() - // The test iterates over client concurrency, so set other parameters. concurrency := []int{1, 25, 50, 100, 1000} for _, c := range concurrency { b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { hey := &tools.Hey{ - Requests: 10000, + Requests: c * b.N, Concurrency: c, - Doc: docs["10Kb"], + Doc: httpdDocs["10Kb"], } - runHttpd(b, clientMachine, serverMachine, hey, false /* reverse */) + runHttpd(b, hey, false /* reverse */) }) } } @@ -77,57 +61,30 @@ func BenchmarkReverseHttpdDocSize(b *testing.B) { benchmarkHttpdDocSize(b, true /* reverse */) } +// benchmarkHttpdDocSize iterates through all doc sizes, running subbenchmarks +// for each size. func benchmarkHttpdDocSize(b *testing.B, reverse bool) { b.Helper() - - clientMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get machine: %v", err) - } - defer clientMachine.CleanUp() - - serverMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get machine: %v", err) - } - defer serverMachine.CleanUp() - - for name, filename := range docs { + for name, filename := range httpdDocs { concurrency := []int{1, 25, 50, 100, 1000} for _, c := range concurrency { b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) { hey := &tools.Hey{ - Requests: 10000, + Requests: c * b.N, Concurrency: c, Doc: filename, } - runHttpd(b, clientMachine, serverMachine, hey, reverse) + runHttpd(b, hey, reverse) }) } } } -// runHttpd runs a single test run. -func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey, reverse bool) { - b.Helper() - - // Grab a container from the server. - ctx := context.Background() - var server *dockerutil.Container - if reverse { - server = serverMachine.GetNativeContainer(ctx, b) - } else { - server = serverMachine.GetContainer(ctx, b) - } - - defer server.CleanUp(ctx) - - // Copy the docs to /tmp and serve from there. - cmd := "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X" +// runHttpd configures the static serving methods to run httpd. +func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) { + // httpd runs on port 80. port := 80 - - // Start the server. - if err := server.Spawn(ctx, dockerutil.RunOpts{ + httpdRunOpts := dockerutil.RunOpts{ Image: "benchmarks/httpd", Ports: []int{port}, Env: []string{ @@ -138,44 +95,7 @@ func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *t "APACHE_LOG_DIR=/tmp", "APACHE_PID_FILE=/tmp/apache.pid", }, - }, "sh", "-c", cmd); err != nil { - b.Fatalf("failed to start server: %v", err) - } - - ip, err := serverMachine.IPAddress() - if err != nil { - b.Fatalf("failed to find server ip: %v", err) - } - - servingPort, err := server.FindPort(ctx, port) - if err != nil { - b.Fatalf("failed to find server port %d: %v", port, err) - } - - // Check the server is serving. - harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) - - var client *dockerutil.Container - // Grab a client. - if reverse { - client = clientMachine.GetContainer(ctx, b) - } else { - client = clientMachine.GetNativeContainer(ctx, b) - } - defer client.CleanUp(ctx) - - b.ResetTimer() - server.RestartProfiles() - for i := 0; i < b.N; i++ { - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/hey", - }, hey.MakeCmd(ip, servingPort)...) - if err != nil { - b.Fatalf("run failed with: %v", err) - } - - b.StopTimer() - hey.Report(b, out) - b.StartTimer() } + httpdCmd := []string{"sh", "-c", "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"} + runStaticServer(b, httpdRunOpts, httpdCmd, port, hey, reverse) } diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go index 2bf1a3624..036fd666f 100644 --- a/test/benchmarks/network/nginx_test.go +++ b/test/benchmarks/network/nginx_test.go @@ -14,91 +14,80 @@ package network import ( - "context" "fmt" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" ) +// see Dockerfile '//images/benchmarks/nginx'. +var nginxDocs = map[string]string{ + "notfound": "notfound", + "1Kb": "latin1k.txt", + "10Kb": "latin10k.txt", + "100Kb": "latin100k.txt", + "1Mb": "latin1024k.txt", + "10Mb": "latin10240k.txt", +} + // BenchmarkNginxConcurrency iterates the concurrency argument and tests // how well the runtime under test handles requests in parallel. -// TODO(gvisor.dev/issue/3536): Update with different doc sizes like Httpd. func BenchmarkNginxConcurrency(b *testing.B) { - // Grab a machine for the client and server. - clientMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get client: %v", err) - } - defer clientMachine.CleanUp() - - serverMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get server: %v", err) - } - defer serverMachine.CleanUp() - - concurrency := []int{1, 5, 10, 25} + concurrency := []int{1, 25, 100, 1000} for _, c := range concurrency { b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { hey := &tools.Hey{ - Requests: 10000, + Requests: c * b.N, Concurrency: c, + Doc: nginxDocs["10kb"], // see Dockerfile '//images/benchmarks/nginx' and httpd_test. } - runNginx(b, clientMachine, serverMachine, hey) + runNginx(b, hey, false /* reverse */) }) } } -// runHttpd runs a single test run. -func runNginx(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey) { - b.Helper() - - // Grab a container from the server. - ctx := context.Background() - server := serverMachine.GetContainer(ctx, b) - defer server.CleanUp(ctx) +// BenchmarkNginxDocSize iterates over different sized payloads, testing how +// well the runtime handles sending different payload sizes. +func BenchmarkNginxDocSize(b *testing.B) { + benchmarkHttpdDocSize(b, false /* reverse */) +} - port := 80 - // Start the server. - if err := server.Spawn(ctx, - dockerutil.RunOpts{ - Image: "benchmarks/nginx", - Ports: []int{port}, - }); err != nil { - b.Fatalf("server failed to start: %v", err) - } +// BenchmarkReverseNginxDocSize iterates over different sized payloads, testing +// how well the runtime handles receiving different payload sizes. +func BenchmarkReverseNginxDocSize(b *testing.B) { + benchmarkHttpdDocSize(b, true /* reverse */) +} - ip, err := serverMachine.IPAddress() - if err != nil { - b.Fatalf("failed to find server ip: %v", err) +// benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks +// for each size. +func benchmarkNginxDocSize(b *testing.B, reverse bool) { + b.Helper() + for name, filename := range nginxDocs { + concurrency := []int{1, 25, 50, 100, 1000} + for _, c := range concurrency { + b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: filename, + } + runNginx(b, hey, reverse) + }) + } } +} - servingPort, err := server.FindPort(ctx, port) - if err != nil { - b.Fatalf("failed to find server port %d: %v", port, err) +// runNginx configures the static serving methods to run httpd. +func runNginx(b *testing.B, hey *tools.Hey, reverse bool) { + // nginx runs on port 80. + port := 80 + nginxRunOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + Ports: []int{port}, } - // Check the server is serving. - harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) - - // Grab a client. - client := clientMachine.GetNativeContainer(ctx, b) - defer client.CleanUp(ctx) - - b.ResetTimer() - server.RestartProfiles() - for i := 0; i < b.N; i++ { - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/hey", - }, hey.MakeCmd(ip, servingPort)...) - if err != nil { - b.Fatalf("run failed with: %v", err) - } - b.StopTimer() - hey.Report(b, out) - b.StartTimer() - } + // Command copies nginxDocs to tmpfs serving directory and runs nginx. + nginxCmd := []string{"sh", "-c", "mkdir -p /tmp/html && cp -a /local/* /tmp/html && nginx"} + runStaticServer(b, nginxRunOpts, nginxCmd, port, hey, reverse) } diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go index 52eb794c4..0f4a205b6 100644 --- a/test/benchmarks/network/node_test.go +++ b/test/benchmarks/network/node_test.go @@ -48,14 +48,14 @@ func runNode(b *testing.B, hey *tools.Hey) { // The machine to hold Redis and the Node Server. serverMachine, err := h.GetMachine() if err != nil { - b.Fatal("failed to get machine with: %v", err) + b.Fatalf("failed to get machine with: %v", err) } defer serverMachine.CleanUp() // The machine to run 'hey'. clientMachine, err := h.GetMachine() if err != nil { - b.Fatal("failed to get machine with: %v", err) + b.Fatalf("failed to get machine with: %v", err) } defer clientMachine.CleanUp() diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go index 5e0b2b724..67f63f76a 100644 --- a/test/benchmarks/network/ruby_test.go +++ b/test/benchmarks/network/ruby_test.go @@ -47,14 +47,14 @@ func runRuby(b *testing.B, hey *tools.Hey) { // The machine to hold Redis and the Ruby Server. serverMachine, err := h.GetMachine() if err != nil { - b.Fatal("failed to get machine with: %v", err) + b.Fatalf("failed to get machine with: %v", err) } defer serverMachine.CleanUp() // The machine to run 'hey'. clientMachine, err := h.GetMachine() if err != nil { - b.Fatal("failed to get machine with: %v", err) + b.Fatalf("failed to get machine with: %v", err) } defer clientMachine.CleanUp() ctx := context.Background() diff --git a/test/benchmarks/network/static_server.go b/test/benchmarks/network/static_server.go new file mode 100644 index 000000000..3ef62a71f --- /dev/null +++ b/test/benchmarks/network/static_server.go @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// runStaticServer runs static serving workloads (httpd, nginx). +func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) { + b.Helper() + ctx := context.Background() + + // Get two machines: a client and server. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Make the containers. 'reverse=true' specifies that the client should use the + // runtime under test. + var client, server *dockerutil.Container + if reverse { + client = clientMachine.GetContainer(ctx, b) + server = serverMachine.GetNativeContainer(ctx, b) + } else { + client = clientMachine.GetNativeContainer(ctx, b) + server = serverMachine.GetContainer(ctx, b) + } + defer client.CleanUp(ctx) + defer server.CleanUp(ctx) + + // Start the server. + if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil { + b.Fatalf("failed to start server: %v", err) + } + + // Get its IP. + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + // Get the published port. + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find server port %d: %v", port, err) + } + + // Make sure the server is serving. + harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) + b.ResetTimer() + server.RestartProfiles() + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, hey.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("run failed with: %v", err) + } + + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go index 4b7ca7a14..6cabfb451 100644 --- a/test/benchmarks/tcp/tcp_proxy.go +++ b/test/benchmarks/tcp/tcp_proxy.go @@ -228,19 +228,26 @@ func newNetstackImpl(mode string) (impl, error) { }) // Set protocol options. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %s", err) + { + opt := tcpip.TCPSACKEnabled(*sack) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Enable Receive Buffer Auto-Tuning. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(*moderateRecvBuf)); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption failed: %s", err) + { + opt := tcpip.TCPModerateReceiveBufferOption(*moderateRecvBuf) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } } // Set Congestion Control to cubic if requested. if *cubic { - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil { - return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %s", err) + opt := tcpip.CongestionControlOption("cubic") + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 809244bab..8425abecb 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -64,9 +64,10 @@ func TestLifeCycle(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 80 if err := d.Create(ctx, dockerutil.RunOpts{ Image: "basic/nginx", - Ports: []int{80}, + Ports: []int{port}, }); err != nil { t.Fatalf("docker create failed: %v", err) } @@ -74,16 +75,15 @@ func TestLifeCycle(t *testing.T) { t.Fatalf("docker start failed: %v", err) } - // Test that container is working. - port, err := d.FindPort(ctx, 80) + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } client := http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Errorf("http request failed: %v", err) } @@ -105,27 +105,28 @@ func TestPauseResume(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 8080 if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", - Ports: []int{8080}, // See Dockerfile. + Ports: []int{port}, // See Dockerfile. }); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check that container is working. client := http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } @@ -135,7 +136,7 @@ func TestPauseResume(t *testing.T) { // Check if container is paused. client = http.Client{Timeout: 10 * time.Millisecond} // Don't wait a minute. - switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) { + switch _, err := client.Get(fmt.Sprintf("http://%s:%d", ip.String(), port)); v := err.(type) { case nil: t.Errorf("http req expected to fail but it succeeded") case net.Error: @@ -151,13 +152,13 @@ func TestPauseResume(t *testing.T) { } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. client = http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } } @@ -179,9 +180,10 @@ func TestCheckpointRestore(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 8080 if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", - Ports: []int{8080}, // See Dockerfile. + Ports: []int{port}, // See Dockerfile. }); err != nil { t.Fatalf("docker run failed: %v", err) } @@ -199,20 +201,20 @@ func TestCheckpointRestore(t *testing.T) { t.Fatalf("docker restore failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. client := http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } } diff --git a/test/fuse/BUILD b/test/fuse/BUILD index 56157c96b..8e31fdd41 100644 --- a/test/fuse/BUILD +++ b/test/fuse/BUILD @@ -5,5 +5,69 @@ package(licenses = ["notice"]) syscall_test( fuse = "True", test = "//test/fuse/linux:stat_test", - vfs2 = "True", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:open_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:release_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:mknod_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:symlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:readlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:mkdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:read_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:write_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:rmdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:readdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:create_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:unlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:setstat_test", ) diff --git a/test/fuse/README.md b/test/fuse/README.md index 734c3a4e3..65add57e2 100644 --- a/test/fuse/README.md +++ b/test/fuse/README.md @@ -1,65 +1,114 @@ # gVisor FUSE Test Suite -This is an integration test suite for fuse(4) filesystem. It runs under both -gVisor and Linux, and ensures compatibility between the two. This test suite is -based on system calls test. +This is an integration test suite for fuse(4) filesystem. It runs under gVisor +sandbox container with VFS2 and FUSE function enabled. -This document describes the framework of fuse integration test and the -guidelines that should be followed when adding new fuse tests. +This document describes the framework of FUSE integration test, how to use it, +and the guidelines that should be followed when adding new testing features. ## Integration Test Framework -Please refer to the figure below. `>` is entering the function, `<` is leaving -the function, and `=` indicates sequentially entering and leaving. +By inheriting the `FuseTest` class defined in `linux/fuse_base.h`, every test +fixture can run in an environment with `mount_point_` mounted by a fake FUSE +server. It creates a `socketpair(2)` to send and receive control commands and +data between the client and the server. Because the FUSE server runs in the +background thread, gTest cannot catch its assertion failure immediately. Thus, +`TearDown()` function sends command to the FUSE server to check if all gTest +assertion in the server are successful and all requests and preset responses are +consumed. + +## Communication Diagram + +Diagram below describes how a testing thread communicates with the FUSE server +to achieve integration test. + +For the following diagram, `>` means entering the function, `<` is leaving the +function, and `=` indicates sequentially entering and leaving. Not necessarily +follow exactly the below diagram due to the nature of a multi-threaded system, +however, it is still helpful to know when the client waits for the server to +complete a command and when the server awaits the next instruction. ``` - | Client (Test Main Process) | Server (FUSE Daemon) - | | - | >TEST_F() | - | >SetUp() | - | =MountFuse() | - | >SetUpFuseServer() | - | [create communication pipes] | - | =fork() | =fork() - | >WaitCompleted() | - | [wait for MarkDone()] | - | | =ConsumeFuseInit() - | | =MarkDone() - | <WaitCompleted() | - | <SetUpFuseServer() | - | <SetUp() | - | >SetExpected() | - | [construct expected reaction] | - | | >FuseLoop() - | | >ReceiveExpected() - | | [wait data from pipe] - | [write data to pipe] | - | [wait for MarkDone()] | - | | [save data to memory] - | | =MarkDone() - | <SetExpected() | - | | <ReceiveExpected() - | | >read() - | | [wait for fs operation] - | >[Do fs operation] | - | [wait for fs response] | - | | <read() - | | =CompareRequest() - | | =write() [write fs response] - | <[Do fs operation] | - | =[Test fs operation result] | - | =[wait for MarkDone()] | - | | =MarkDone() - | >TearDown() | - | =UnmountFuse() | - | <TearDown() | - | <TEST_F() | +| Client (Testing Thread) | Server (FUSE Server Thread) +| | +| >TEST_F() | +| >SetUp() | +| =MountFuse() | +| >SetUpFuseServer() | +| [create communication socket]| +| =fork() | =fork() +| [wait server complete] | +| | =ServerConsumeFuseInit() +| | =ServerCompleteWith() +| <SetUpFuseServer() | +| <SetUp() | +| [testing main] | +| | >ServerFuseLoop() +| | [poll on socket and fd] +| >SetServerResponse() | +| [write data to socket] | +| [wait server complete] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerReceiveResponse() +| | [read data from socket] +| | [save data to memory] +| | <ServerReceiveResponse() +| | =ServerCompleteWith() +| <SetServerResponse() | +| | <ServerHandleCommand() +| >[Do fs operation] | +| [wait for fs response] | +| | [fd event occurs] +| | >ServerProcessFuseRequest() +| | =[read fs request] +| | =[save fs request to memory] +| | =[write fs response] +| <[Do fs operation] | +| | <ServerProcessFuseRequest() +| | +| =[Test fs operation result] | +| | +| >GetServerActualRequest() | +| [write data to socket] | +| [wait data from server] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerSendReceivedRequest() +| | [write data to socket] +| [read data from socket] | +| [wait server complete] | +| | <ServerSendReceivedRequest() +| | =ServerCompleteWith() +| <GetServerActualRequest() | +| | <ServerHandleCommand() +| | +| =[Test actual request] | +| | +| >TearDown() | +| ... | +| >GetServerNumUnsentResponses() | +| [write data to socket] | +| [wait server complete] | +| | [socket event arrive] +| | >ServerHandleCommand() +| | >ServerSendData() +| | [write data to socket] +| | <ServerSendData() +| | =ServerCompleteWith() +| [read data from socket] | +| [test if all succeeded] | +| <GetServerNumUnsentResponses() | +| | <ServerHandleCommand() +| =UnmountFuse() | +| <TearDown() | +| <TEST_F() | ``` ## Running the tests -Based on syscall tests, fuse tests can run in different environments. To enable -fuse testing environment, the test targets should be appended with `_fuse`. +Based on syscall tests, FUSE tests generate targets only with vfs2 and fuse +enabled. The corresponding targets end in `_fuse`. For example, to run fuse test in `stat_test.cc`: @@ -77,17 +126,16 @@ $ bazel test --test_tag_filters=fuse //test/fuse/... 1. Add test targets in `BUILD` and `linux/BUILD`. 2. Inherit your test from `FuseTest` base class. It allows you to: - - Run a fake FUSE server in background during each test setup. - - Create pipes for communication and provide utility functions. - - Stop FUSE server after test completes. -3. Customize your comparison function for request assessment in FUSE server. -4. Add the mapping of the size of structs if you are working on new FUSE - opcode. - - Please update `FuseTest::GetPayloadSize()` for each new FUSE opcode. -5. Build the expected request-response pair of your FUSE operation. -6. Call `SetExpected()` function to inject the expected reaction. -7. Check the response and/or errors. -8. Finally call `WaitCompleted()` to ensure the FUSE server acts correctly. + - Fork a fake FUSE server in background during each test setup. + - Create a pair of sockets for communication and provide utility + functions. + - Stop FUSE server and check if error occurs in it after test completes. +3. Build the expected opcode-response pairs of your FUSE operation. +4. Call `SetServerResponse()` to preset the next expected opcode and response. +5. Do real filesystem operations (FUSE is mounted at `mount_point_`). +6. Check FUSE response and/or errors. +7. Retrieve FUSE request by `GetServerActualRequest()`. +8. Check if the request is as expected. A few customized matchers used in syscalls test are encouraged to test the outcome of filesystem operations. Such as: @@ -101,3 +149,40 @@ SyscallFailsWithErrno(...) Please refer to [test/syscalls/README.md](../syscalls/README.md) for further details. + +## Writing a new FuseTestCmd + +A `FuseTestCmd` is a control protocol used in the communication between the +testing thread and the FUSE server. Such commands are sent from the testing +thread to the FUSE server to set up, control, or inspect the behavior of the +FUSE server in response to a sequence of FUSE requests. + +The lifecycle of a command contains following steps: + +1. The testing thread sends a `FuseTestCmd` via socket and waits for + completion. +2. The FUSE server receives the command and does corresponding action. +3. (Optional) The testing thread reads data from socket. +4. The FUSE server sends a success indicator via socket after processing. +5. The testing thread gets the success signal and continues testing. + +The success indicator, i.e. `WaitServerComplete()`, is crucial at the end of +each `FuseTestCmd` sent from the testing thread. Because we don't want to begin +filesystem operation if the requests have not been completely set up. Also, to +test FUSE interactions in a sequential manner, concurrent requests are not +supported now. + +To add a new `FuseTestCmd`, one must comply with following format: + +1. Add a new `FuseTestCmd` enum class item defined in `linux/fuse_base.h` +2. Add a `SetServerXXX()` or `GetServerXXX()` public function in `FuseTest`. + This is how the testing thread will call to send control message. Define how + many bytes you want to send along with the command and what you will expect + to receive. Finally it should block and wait for a success indicator from + the FUSE server. +3. Add a handler logic in the switch condition of `ServerHandleCommand()`. Use + `ServerSendData()` or declare a new private function such as + `ServerReceiveXXX()` or `ServerSendXXX()`. It is mandatory to set it private + since only the FUSE server (forked from `FuseTest` base class) can call it. + This is the server part of the specific `FuseTestCmd` and the format of the + data should be consistent with what the client expects in the previous step. diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD index 4871bb531..7673252ec 100644 --- a/test/fuse/linux/BUILD +++ b/test/fuse/linux/BUILD @@ -11,7 +11,134 @@ cc_binary( srcs = ["stat_test.cc"], deps = [ gtest, + ":fuse_fd_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "open_test", + testonly = 1, + srcs = ["open_test.cc"], + deps = [ + gtest, ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "release_test", + testonly = 1, + srcs = ["release_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "mknod_test", + testonly = 1, + srcs = ["mknod_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "symlink_test", + testonly = 1, + srcs = ["symlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "readlink_test", + testonly = 1, + srcs = ["readlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "mkdir_test", + testonly = 1, + srcs = ["mkdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "setstat_test", + testonly = 1, + srcs = ["setstat_test.cc"], + deps = [ + gtest, + ":fuse_fd_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "rmdir_test", + testonly = 1, + srcs = ["rmdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "readdir_test", + testonly = 1, + srcs = ["readdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", "//test/util:test_main", "//test/util:test_util", ], @@ -24,9 +151,80 @@ cc_library( hdrs = ["fuse_base.h"], deps = [ gtest, + "//test/util:fuse_util", "//test/util:posix_error", "//test/util:temp_path", "//test/util:test_util", "@com_google_absl//absl/strings:str_format", ], ) + +cc_library( + name = "fuse_fd_util", + testonly = 1, + srcs = ["fuse_fd_util.cc"], + hdrs = ["fuse_fd_util.h"], + deps = [ + gtest, + ":fuse_base", + "//test/util:cleanup", + "//test/util:file_descriptor", + "//test/util:fuse_util", + "//test/util:posix_error", + ], +) + +cc_binary( + name = "read_test", + testonly = 1, + srcs = ["read_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "write_test", + testonly = 1, + srcs = ["write_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "create_test", + testonly = 1, + srcs = ["create_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "unlink_test", + testonly = 1, + srcs = ["unlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/fuse/linux/create_test.cc b/test/fuse/linux/create_test.cc new file mode 100644 index 000000000..9a0219a58 --- /dev/null +++ b/test/fuse/linux/create_test.cc @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class CreateTest : public FuseTest { + protected: + const std::string test_file_name_ = "test_file"; + const mode_t mode = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(CreateTest, CreateFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_name_); + + // Ensure the file doesn't exist. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_LOOKUP, iov_out); + + // creat(2) is equal to open(2) with open_flags O_CREAT | O_WRONLY | O_TRUNC. + const mode_t new_mask = S_IWGRP | S_IWOTH; + const int open_flags = O_CREAT | O_WRONLY | O_TRUNC; + out_header.error = 0; + out_header.len = sizeof(struct fuse_out_header) + + sizeof(struct fuse_entry_out) + sizeof(struct fuse_open_out); + struct fuse_entry_out entry_payload = DefaultEntryOut(mode & ~new_mask, 2); + struct fuse_open_out out_payload = { + .fh = 1, + .open_flags = open_flags, + }; + iov_out = FuseGenerateIovecs(out_header, entry_payload, out_payload); + SetServerResponse(FUSE_CREATE, iov_out); + + // kernfs generates a successive FUSE_OPEN after the file is created. Linux's + // fuse kernel module will not send this FUSE_OPEN after creat(2). + out_header.len = + sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out); + iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + + int fd; + TempUmask mask(new_mask); + EXPECT_THAT(fd = creat(test_file_path.c_str(), mode), SyscallSucceeds()); + EXPECT_THAT(fcntl(fd, F_GETFL), + SyscallSucceedsWithValue(open_flags & O_ACCMODE)); + + struct fuse_in_header in_header; + struct fuse_create_in in_payload; + std::vector<char> name(test_file_name_.size() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, name); + + // Skip the request of FUSE_LOOKUP. + SkipServerActualRequest(); + + // Get the first FUSE_CREATE. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload) + + test_file_name_.size() + 1); + EXPECT_EQ(in_header.opcode, FUSE_CREATE); + EXPECT_EQ(in_payload.flags, open_flags); + EXPECT_EQ(in_payload.mode, mode & ~new_mask); + EXPECT_EQ(in_payload.umask, new_mask); + EXPECT_EQ(std::string(name.data()), test_file_name_); + + // Get the successive FUSE_OPEN. + struct fuse_open_in in_payload_open; + iov_in = FuseGenerateIovecs(in_header, in_payload_open); + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload_open)); + EXPECT_EQ(in_header.opcode, FUSE_OPEN); + EXPECT_EQ(in_payload_open.flags, open_flags & O_ACCMODE); + + EXPECT_THAT(close(fd), SyscallSucceeds()); + // Skip the FUSE_RELEASE. + SkipServerActualRequest(); +} + +TEST_F(CreateTest, CreateFileAlreadyExists) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_name_); + + const int open_flags = O_CREAT | O_EXCL; + + SetServerInodeLookup(test_file_name_); + + EXPECT_THAT(open(test_file_path.c_str(), mode, open_flags), + SyscallFailsWithErrno(EEXIST)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc index 9c3124472..5b45804e1 100644 --- a/test/fuse/linux/fuse_base.cc +++ b/test/fuse/linux/fuse_base.cc @@ -16,17 +16,17 @@ #include <fcntl.h> #include <linux/fuse.h> -#include <string.h> +#include <poll.h> #include <sys/mount.h> +#include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> #include <sys/uio.h> #include <unistd.h> -#include <iostream> - #include "gtest/gtest.h" #include "absl/strings/str_format.h" +#include "test/util/fuse_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -39,49 +39,123 @@ void FuseTest::SetUp() { SetUpFuseServer(); } -void FuseTest::TearDown() { UnmountFuse(); } - -// Since CompareRequest is running in background thread, gTest assertions and -// expectations won't directly reflect the test result. However, the FUSE -// background server still connects to the same standard I/O as testing main -// thread. So EXPECT_XX can still be used to show different results. To -// ensure failed testing result is observable, return false and the result -// will be sent to test main thread via pipe. -bool FuseTest::CompareRequest(void* expected_mem, size_t expected_len, - void* real_mem, size_t real_len) { - if (expected_len != real_len) return false; - return memcmp(expected_mem, real_mem, expected_len) == 0; +void FuseTest::TearDown() { + EXPECT_EQ(GetServerNumUnconsumedRequests(), 0); + EXPECT_EQ(GetServerNumUnsentResponses(), 0); + UnmountFuse(); } -// SetExpected is called by the testing main thread to set expected request- -// response pair of a single FUSE operation. -void FuseTest::SetExpected(struct iovec* iov_in, int iov_in_cnt, - struct iovec* iov_out, int iov_out_cnt) { - EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_in, iov_in_cnt), - SyscallSucceedsWithValue(::testing::Gt(0))); - WaitCompleted(); +// Sends 3 parts of data to the FUSE server: +// 1. The `kSetResponse` command +// 2. The expected opcode +// 3. The fake FUSE response +// Then waits for the FUSE server to notify its completion. +void FuseTest::SetServerResponse(uint32_t opcode, + std::vector<struct iovec>& iovecs) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetResponse); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &opcode, sizeof(opcode)), + SyscallSucceedsWithValue(sizeof(opcode))); - EXPECT_THAT(RetryEINTR(writev)(set_expected_[1], iov_out, iov_out_cnt), - SyscallSucceedsWithValue(::testing::Gt(0))); - WaitCompleted(); + EXPECT_THAT(RetryEINTR(writev)(sock_[0], iovecs.data(), iovecs.size()), + SyscallSucceeds()); + + WaitServerComplete(); } -// WaitCompleted waits for the FUSE server to finish its job and check if it +// Waits for the FUSE server to finish its blocking job and check if it // completes without errors. -void FuseTest::WaitCompleted() { - char success; - EXPECT_THAT(RetryEINTR(read)(done_[0], &success, sizeof(success)), - SyscallSucceedsWithValue(1)); +void FuseTest::WaitServerComplete() { + uint32_t success; + EXPECT_THAT(RetryEINTR(read)(sock_[0], &success, sizeof(success)), + SyscallSucceedsWithValue(sizeof(success))); + ASSERT_EQ(success, 1); +} + +// Sends the `kGetRequest` command to the FUSE server, then reads the next +// request into iovec struct. The order of calling this function should be +// the same as the one of SetServerResponse(). +void FuseTest::GetServerActualRequest(std::vector<struct iovec>& iovecs) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kGetRequest); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(readv)(sock_[0], iovecs.data(), iovecs.size()), + SyscallSucceeds()); + + WaitServerComplete(); +} + +// Sends a FuseTestCmd command to the FUSE server, reads from the socket, and +// returns the corresponding data. +uint32_t FuseTest::GetServerData(uint32_t cmd) { + uint32_t data; + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(read)(sock_[0], &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); + + WaitServerComplete(); + return data; +} + +uint32_t FuseTest::GetServerNumUnconsumedRequests() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetNumUnconsumedRequests)); +} + +uint32_t FuseTest::GetServerNumUnsentResponses() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetNumUnsentResponses)); +} + +uint32_t FuseTest::GetServerTotalReceivedBytes() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetTotalReceivedBytes)); +} + +// Sends the `kSkipRequest` command to the FUSE server, which would skip +// current stored request data. +void FuseTest::SkipServerActualRequest() { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSkipRequest); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + WaitServerComplete(); } -void FuseTest::MountFuse() { +// Sends the `kSetInodeLookup` command, expected mode, and the path of the +// inode to create under the mount point. +void FuseTest::SetServerInodeLookup(const std::string& path, mode_t mode, + uint64_t size) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetInodeLookup); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &mode, sizeof(mode)), + SyscallSucceedsWithValue(sizeof(mode))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &size, sizeof(size)), + SyscallSucceedsWithValue(sizeof(size))); + + // Pad 1 byte for null-terminate c-string. + EXPECT_THAT(RetryEINTR(write)(sock_[0], path.c_str(), path.size() + 1), + SyscallSucceedsWithValue(path.size() + 1)); + + WaitServerComplete(); +} + +void FuseTest::MountFuse(const char* mountOpts) { EXPECT_THAT(dev_fd_ = open("/dev/fuse", O_RDWR), SyscallSucceeds()); - std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, kMountOpts); + std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, mountOpts); mount_point_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(mount("fuse", mount_point_.path().c_str(), "fuse", MS_NODEV | MS_NOSUID, mount_opts.c_str()), - SyscallSucceedsWithValue(0)); + SyscallSucceeds()); } void FuseTest::UnmountFuse() { @@ -89,13 +163,13 @@ void FuseTest::UnmountFuse() { // TODO(gvisor.dev/issue/3330): ensure the process is terminated successfully. } -// ConsumeFuseInit consumes the first FUSE request and returns the -// corresponding PosixError. -PosixError FuseTest::ConsumeFuseInit() { +// Consumes the first FUSE request and returns the corresponding PosixError. +PosixError FuseTest::ServerConsumeFuseInit( + const struct fuse_init_out* out_payload) { + std::vector<char> buf(FUSE_MIN_READ_BUFFER); RETURN_ERROR_IF_SYSCALL_FAIL( - RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size())); + RetryEINTR(read)(dev_fd_, buf.data(), buf.size())); - struct iovec iov_out[2]; struct fuse_out_header out_header = { .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_init_out), .error = 0, @@ -103,72 +177,74 @@ PosixError FuseTest::ConsumeFuseInit() { }; // Returns a fake fuse_init_out with 7.0 version to avoid ECONNREFUSED // error in the initialization of FUSE connection. - struct fuse_init_out out_payload = { - .major = 7, - }; - iov_out[0].iov_len = sizeof(out_header); - iov_out[0].iov_base = &out_header; - iov_out[1].iov_len = sizeof(out_payload); - iov_out[1].iov_base = &out_payload; + auto iov_out = FuseGenerateIovecs( + out_header, *const_cast<struct fuse_init_out*>(out_payload)); - RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(writev)(dev_fd_, iov_out, 2)); + RETURN_ERROR_IF_SYSCALL_FAIL( + RetryEINTR(writev)(dev_fd_, iov_out.data(), iov_out.size())); return NoError(); } -// ReceiveExpected reads 1 pair of expected fuse request-response `iovec`s -// from pipe and save them into member variables of this testing instance. -void FuseTest::ReceiveExpected() { - // Set expected fuse_in request. - EXPECT_THAT(len_in_ = RetryEINTR(read)(set_expected_[0], mem_in_.data(), - mem_in_.size()), - SyscallSucceedsWithValue(::testing::Gt(0))); - MarkDone(len_in_ > 0); +// Reads 1 expected opcode and a fake response from socket and save them into +// the serial buffer of this testing instance. +void FuseTest::ServerReceiveResponse() { + ssize_t len; + uint32_t opcode; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + EXPECT_THAT(RetryEINTR(read)(sock_[1], &opcode, sizeof(opcode)), + SyscallSucceedsWithValue(sizeof(opcode))); - // Set expected fuse_out response. - EXPECT_THAT(len_out_ = RetryEINTR(read)(set_expected_[0], mem_out_.data(), - mem_out_.size()), - SyscallSucceedsWithValue(::testing::Gt(0))); - MarkDone(len_out_ > 0); + EXPECT_THAT(len = RetryEINTR(read)(sock_[1], buf.data(), buf.size()), + SyscallSucceeds()); + + responses_.AddMemBlock(opcode, buf.data(), len); } -// MarkDone writes 1 byte of success indicator through pipe. -void FuseTest::MarkDone(bool success) { - char data = success ? 1 : 0; - EXPECT_THAT(RetryEINTR(write)(done_[1], &data, sizeof(data)), - SyscallSucceedsWithValue(1)); +// Writes 1 byte of success indicator through socket. +void FuseTest::ServerCompleteWith(bool success) { + uint32_t data = success ? 1 : 0; + ServerSendData(data); } -// FuseLoop is the implementation of the fake FUSE server. Read from /dev/fuse, -// compare the request by CompareRequest (use derived function if specified), -// and write the expected response to /dev/fuse. -void FuseTest::FuseLoop() { - bool success = true; - ssize_t len = 0; +// ServerFuseLoop is the implementation of the fake FUSE server. Monitors 2 +// file descriptors: /dev/fuse and sock_[1]. Events from /dev/fuse are FUSE +// requests and events from sock_[1] are FUSE testing commands, leading by +// a FuseTestCmd data to indicate the command. +void FuseTest::ServerFuseLoop() { + const int nfds = 2; + struct pollfd fds[nfds] = { + { + .fd = dev_fd_, + .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL, + }, + { + .fd = sock_[1], + .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL, + }, + }; + while (true) { - ReceiveExpected(); + ASSERT_THAT(poll(fds, nfds, -1), SyscallSucceeds()); - EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf_.data(), buf_.size()), - SyscallSucceedsWithValue(len_in_)); - if (len != len_in_) success = false; + for (int fd_idx = 0; fd_idx < nfds; ++fd_idx) { + if (fds[fd_idx].revents == 0) continue; - if (!CompareRequest(buf_.data(), len_in_, mem_in_.data(), len_in_)) { - std::cerr << "the FUSE request is not expected" << std::endl; - success = false; + ASSERT_EQ(fds[fd_idx].revents, POLL_IN); + if (fds[fd_idx].fd == sock_[1]) { + ServerHandleCommand(); + } else if (fds[fd_idx].fd == dev_fd_) { + ServerProcessFuseRequest(); + } } - - EXPECT_THAT(len = RetryEINTR(write)(dev_fd_, mem_out_.data(), len_out_), - SyscallSucceedsWithValue(len_out_)); - if (len != len_out_) success = false; - MarkDone(success); } } -// SetUpFuseServer creates 2 pipes. First is for testing client to send the -// expected request-response pair, and the other acts as a checkpoint for the -// FUSE server to notify the client that it can proceed. -void FuseTest::SetUpFuseServer() { - ASSERT_THAT(pipe(set_expected_), SyscallSucceedsWithValue(0)); - ASSERT_THAT(pipe(done_), SyscallSucceedsWithValue(0)); +// SetUpFuseServer creates 1 socketpair and fork the process. The parent thread +// becomes testing thread and the child thread becomes the FUSE server running +// in background. These 2 threads are connected via socketpair. sock_[0] is +// opened in testing thread and sock_[1] is opened in the FUSE server. +void FuseTest::SetUpFuseServer(const struct fuse_init_out* payload) { + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_), SyscallSucceeds()); switch (fork()) { case -1: @@ -177,31 +253,194 @@ void FuseTest::SetUpFuseServer() { case 0: break; default: - ASSERT_THAT(close(set_expected_[0]), SyscallSucceedsWithValue(0)); - ASSERT_THAT(close(done_[1]), SyscallSucceedsWithValue(0)); - WaitCompleted(); + ASSERT_THAT(close(sock_[1]), SyscallSucceeds()); + WaitServerComplete(); return; } - ASSERT_THAT(close(set_expected_[1]), SyscallSucceedsWithValue(0)); - ASSERT_THAT(close(done_[0]), SyscallSucceedsWithValue(0)); - - MarkDone(ConsumeFuseInit().ok()); - - FuseLoop(); + // Begin child thread, i.e. the FUSE server. + ASSERT_THAT(close(sock_[0]), SyscallSucceeds()); + ServerCompleteWith(ServerConsumeFuseInit(payload).ok()); + ServerFuseLoop(); _exit(0); } -// GetPayloadSize is a helper function to get the number of bytes of a -// specific FUSE operation struct. -size_t FuseTest::GetPayloadSize(uint32_t opcode, bool in) { - switch (opcode) { - case FUSE_INIT: - return in ? sizeof(struct fuse_init_in) : sizeof(struct fuse_init_out); +void FuseTest::ServerSendData(uint32_t data) { + EXPECT_THAT(RetryEINTR(write)(sock_[1], &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); +} + +// Reads FuseTestCmd sent from testing thread and routes to correct handler. +// Since each command should be a blocking operation, a `ServerCompleteWith()` +// is required after the switch keyword. +void FuseTest::ServerHandleCommand() { + uint32_t cmd; + EXPECT_THAT(RetryEINTR(read)(sock_[1], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + switch (static_cast<FuseTestCmd>(cmd)) { + case FuseTestCmd::kSetResponse: + ServerReceiveResponse(); + break; + case FuseTestCmd::kSetInodeLookup: + ServerReceiveInodeLookup(); + break; + case FuseTestCmd::kGetRequest: + ServerSendReceivedRequest(); + break; + case FuseTestCmd::kGetTotalReceivedBytes: + ServerSendData(static_cast<uint32_t>(requests_.UsedBytes())); + break; + case FuseTestCmd::kGetNumUnconsumedRequests: + ServerSendData(static_cast<uint32_t>(requests_.RemainingBlocks())); + break; + case FuseTestCmd::kGetNumUnsentResponses: + ServerSendData(static_cast<uint32_t>(responses_.RemainingBlocks())); + break; + case FuseTestCmd::kSkipRequest: + ServerSkipReceivedRequest(); + break; default: + FAIL() << "Unknown FuseTestCmd " << cmd; break; } - return 0; + + ServerCompleteWith(!HasFailure()); +} + +// Reads the expected file mode and the path of one file. Crafts a basic +// `fuse_entry_out` memory block and inserts into a map for future use. +// The FUSE server will always return this response if a FUSE_LOOKUP +// request with this specific path comes in. +void FuseTest::ServerReceiveInodeLookup() { + mode_t mode; + uint64_t size; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], &mode, sizeof(mode)), + SyscallSucceedsWithValue(sizeof(mode))); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], &size, sizeof(size)), + SyscallSucceedsWithValue(sizeof(size))); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], buf.data(), buf.size()), + SyscallSucceeds()); + + std::string path(buf.data()); + + uint32_t out_len = + sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out); + struct fuse_out_header out_header = { + .len = out_len, + .error = 0, + }; + struct fuse_entry_out out_payload = DefaultEntryOut(mode, nodeid_); + // Since this is only used in test, nodeid_ is simply increased by 1 to + // comply with the unqiueness of different path. + ++nodeid_; + + // Set the size. + out_payload.attr.size = size; + + memcpy(buf.data(), &out_header, sizeof(out_header)); + memcpy(buf.data() + sizeof(out_header), &out_payload, sizeof(out_payload)); + lookups_.AddMemBlock(FUSE_LOOKUP, buf.data(), out_len); + lookup_map_[path] = lookups_.Next(); +} + +// Sends the received request pointed by current cursor and advances cursor. +void FuseTest::ServerSendReceivedRequest() { + if (requests_.End()) { + FAIL() << "No more received request."; + return; + } + auto mem_block = requests_.Next(); + EXPECT_THAT( + RetryEINTR(write)(sock_[1], requests_.DataAtOffset(mem_block.offset), + mem_block.len), + SyscallSucceedsWithValue(mem_block.len)); +} + +// Skip the request pointed by current cursor. +void FuseTest::ServerSkipReceivedRequest() { + if (requests_.End()) { + FAIL() << "No more received request."; + return; + } + requests_.Next(); +} + +// Handles FUSE request. Reads request from /dev/fuse, checks if it has the +// same opcode as expected, and responds with the saved fake FUSE response. +// The FUSE request is copied to the serial buffer and can be retrieved one- +// by-one by calling GetServerActualRequest from testing thread. +void FuseTest::ServerProcessFuseRequest() { + ssize_t len; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + + // Read FUSE request. + EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf.data(), buf.size()), + SyscallSucceeds()); + fuse_in_header* in_header = reinterpret_cast<fuse_in_header*>(buf.data()); + + // Check if this is a preset FUSE_LOOKUP path. + if (in_header->opcode == FUSE_LOOKUP) { + std::string path(buf.data() + sizeof(struct fuse_in_header)); + auto it = lookup_map_.find(path); + if (it != lookup_map_.end()) { + // Matches a preset path. Reply with fake data and skip saving the + // request. + ServerRespondFuseSuccess(lookups_, it->second, in_header->unique); + return; + } + } + + requests_.AddMemBlock(in_header->opcode, buf.data(), len); + + if (in_header->opcode == FUSE_RELEASE || in_header->opcode == FUSE_RELEASEDIR) + return; + // Check if there is a corresponding response. + if (responses_.End()) { + GTEST_NONFATAL_FAILURE_("No more FUSE response is expected"); + ServerRespondFuseError(in_header->unique); + return; + } + auto mem_block = responses_.Next(); + if (in_header->opcode != mem_block.opcode) { + std::string message = absl::StrFormat("Expect opcode %d but got %d", + mem_block.opcode, in_header->opcode); + GTEST_NONFATAL_FAILURE_(message.c_str()); + // We won't get correct response if opcode is not expected. Send error + // response here to avoid wrong parsing by VFS. + ServerRespondFuseError(in_header->unique); + return; + } + + // Write FUSE response. + ServerRespondFuseSuccess(responses_, mem_block, in_header->unique); +} + +void FuseTest::ServerRespondFuseSuccess(FuseMemBuffer& mem_buf, + const FuseMemBlock& block, + uint64_t unique) { + fuse_out_header* out_header = + reinterpret_cast<fuse_out_header*>(mem_buf.DataAtOffset(block.offset)); + + // Patch `unique` in fuse_out_header to avoid EINVAL caused by responding + // with an unknown `unique`. + out_header->unique = unique; + EXPECT_THAT(RetryEINTR(write)(dev_fd_, out_header, block.len), + SyscallSucceedsWithValue(block.len)); +} + +void FuseTest::ServerRespondFuseError(uint64_t unique) { + fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = ENOSYS, + .unique = unique, + }; + EXPECT_THAT(RetryEINTR(write)(dev_fd_, &out_header, sizeof(out_header)), + SyscallSucceedsWithValue(sizeof(out_header))); } } // namespace testing diff --git a/test/fuse/linux/fuse_base.h b/test/fuse/linux/fuse_base.h index 3a2f255a9..6ad296ca2 100644 --- a/test/fuse/linux/fuse_base.h +++ b/test/fuse/linux/fuse_base.h @@ -16,8 +16,12 @@ #define GVISOR_TEST_FUSE_FUSE_BASE_H_ #include <linux/fuse.h> +#include <string.h> +#include <sys/stat.h> #include <sys/uio.h> +#include <iostream> +#include <unordered_map> #include <vector> #include "gtest/gtest.h" @@ -29,68 +33,216 @@ namespace testing { constexpr char kMountOpts[] = "rootmode=755,user_id=0,group_id=0"; -class FuseTest : public ::testing::Test { +constexpr struct fuse_init_out kDefaultFUSEInitOutPayload = {.major = 7}; + +// Internal commands used to communicate between testing thread and the FUSE +// server. See test/fuse/README.md for further detail. +enum class FuseTestCmd { + kSetResponse = 0, + kSetInodeLookup, + kGetRequest, + kGetNumUnconsumedRequests, + kGetNumUnsentResponses, + kGetTotalReceivedBytes, + kSkipRequest, +}; + +// Holds the information of a memory block in a serial buffer. +struct FuseMemBlock { + uint32_t opcode; + size_t offset; + size_t len; +}; + +// A wrapper of a simple serial buffer that can be used with read(2) and +// write(2). Contains a cursor to indicate accessing. This class is not thread- +// safe and can only be used in single-thread version. +class FuseMemBuffer { public: - FuseTest() { - buf_.resize(FUSE_MIN_READ_BUFFER); - mem_in_.resize(FUSE_MIN_READ_BUFFER); - mem_out_.resize(FUSE_MIN_READ_BUFFER); + FuseMemBuffer() : cursor_(0) { + // To read from /dev/fuse, a buffer needs at least FUSE_MIN_READ_BUFFER + // bytes to avoid EINVAL. FuseMemBuffer holds memory that can accommodate + // a sequence of FUSE request/response, so it is initiated with double + // minimal requirement. + mem_.resize(FUSE_MIN_READ_BUFFER * 2); } - void SetUp() override; - void TearDown() override; - // CompareRequest is used by the FUSE server and should be implemented to - // compare different FUSE operations. It compares the actual FUSE input - // request with the expected one set by `SetExpected()`. - virtual bool CompareRequest(void* expected_mem, size_t expected_len, - void* real_mem, size_t real_len); + // Returns whether there is no memory block. + bool Empty() { return blocks_.empty(); } + + // Returns if there is no more remaining memory blocks. + bool End() { return cursor_ == blocks_.size(); } + + // Returns how many bytes that have been received. + size_t UsedBytes() { + return Empty() ? 0 : blocks_.back().offset + blocks_.back().len; + } + + // Returns the available bytes remains in the serial buffer. + size_t AvailBytes() { return mem_.size() - UsedBytes(); } + + // Appends a memory block information that starts at the tail of the serial + // buffer. /dev/fuse requires at least FUSE_MIN_READ_BUFFER bytes to read, or + // it will issue EINVAL. If it is not enough, just double the buffer length. + void AddMemBlock(uint32_t opcode, void* data, size_t len) { + if (AvailBytes() < FUSE_MIN_READ_BUFFER) { + mem_.resize(mem_.size() << 1); + } + size_t offset = UsedBytes(); + memcpy(mem_.data() + offset, data, len); + blocks_.push_back(FuseMemBlock{opcode, offset, len}); + } + + // Returns the memory address at a specific offset. Used with read(2) or + // write(2). + char* DataAtOffset(size_t offset) { return mem_.data() + offset; } + + // Returns current memory block pointed by the cursor and increase by 1. + FuseMemBlock Next() { + if (End()) { + std::cerr << "Buffer is already exhausted." << std::endl; + return FuseMemBlock{}; + } + return blocks_[cursor_++]; + } + + // Returns the number of the blocks that has not been requested. + size_t RemainingBlocks() { return blocks_.size() - cursor_; } + + private: + size_t cursor_; + std::vector<FuseMemBlock> blocks_; + std::vector<char> mem_; +}; - // SetExpected is called by the testing main thread. Writes a request- - // response pair into FUSE server's member variables via pipe. - void SetExpected(struct iovec* iov_in, int iov_in_cnt, struct iovec* iov_out, - int iov_out_cnt); +// FuseTest base class is useful in FUSE integration test. Inherit this class +// to automatically set up a fake FUSE server and use the member functions +// to manipulate with it. Refer to test/fuse/README.md for detailed explanation. +class FuseTest : public ::testing::Test { + public: + // nodeid_ is the ID of a fake inode. We starts from 2 since 1 is occupied by + // the mount point. + FuseTest() : nodeid_(2) {} + void SetUp() override; + void TearDown() override; - // WaitCompleted waits for FUSE server to complete its processing. It - // complains if the FUSE server responds failure during tests. - void WaitCompleted(); + // Called by the testing thread to set up a fake response for an expected + // opcode via socket. This can be used multiple times to define a sequence of + // expected FUSE reactions. + void SetServerResponse(uint32_t opcode, std::vector<struct iovec>& iovecs); + + // Called by the testing thread to install a fake path under the mount point. + // e.g. a file under /mnt/dir/file and moint point is /mnt, then it will look + // up "dir/file" in this case. + // + // It sets a fixed response to the FUSE_LOOKUP requests issued with this + // path, pretending there is an inode and avoid ENOENT when testing. If mode + // is not given, it creates a regular file with mode 0600. + void SetServerInodeLookup(const std::string& path, + mode_t mode = S_IFREG | S_IRUSR | S_IWUSR, + uint64_t size = 512); + + // Called by the testing thread to ask the FUSE server for its next received + // FUSE request. Be sure to use the corresponding struct of iovec to receive + // data from server. + void GetServerActualRequest(std::vector<struct iovec>& iovecs); + + // Called by the testing thread to query the number of unconsumed requests in + // the requests_ serial buffer of the FUSE server. TearDown() ensures all + // FUSE requests received by the FUSE server were consumed by the testing + // thread. + uint32_t GetServerNumUnconsumedRequests(); + + // Called by the testing thread to query the number of unsent responses in + // the responses_ serial buffer of the FUSE server. TearDown() ensures all + // preset FUSE responses were sent out by the FUSE server. + uint32_t GetServerNumUnsentResponses(); + + // Called by the testing thread to ask the FUSE server for its total received + // bytes from /dev/fuse. + uint32_t GetServerTotalReceivedBytes(); + + // Called by the testing thread to ask the FUSE server to skip stored + // request data. + void SkipServerActualRequest(); protected: TempPath mount_point_; - private: - void MountFuse(); + // Opens /dev/fuse and inherit the file descriptor for the FUSE server. + void MountFuse(const char* mountOpts = kMountOpts); + + // Creates a socketpair for communication and forks FUSE server. + void SetUpFuseServer( + const struct fuse_init_out* payload = &kDefaultFUSEInitOutPayload); + + // Unmounts the mountpoint of the FUSE server. void UnmountFuse(); - // ConsumeFuseInit is only used during FUSE server setup. - PosixError ConsumeFuseInit(); + private: + // Sends a FuseTestCmd and gets a uint32_t data from the FUSE server. + inline uint32_t GetServerData(uint32_t cmd); + + // Waits for FUSE server to complete its processing. Complains if the FUSE + // server responds any failure during tests. + void WaitServerComplete(); - // ReceiveExpected is the FUSE server side's corresponding code of - // `SetExpected()`. Save the request-response pair into its memory. - void ReceiveExpected(); + // The FUSE server stays here and waits next command or FUSE request until it + // is terminated. + void ServerFuseLoop(); - // MarkDone is used by the FUSE server to tell testing main if it's OK to - // proceed next command. - void MarkDone(bool success); + // Used by the FUSE server to tell testing thread if it is OK to proceed next + // command. Will be issued after processing each FuseTestCmd. + void ServerCompleteWith(bool success); - // FuseLoop is where the FUSE server stay until it is terminated. - void FuseLoop(); + // Consumes the first FUSE request when mounting FUSE. Replies with a + // response with empty payload. + PosixError ServerConsumeFuseInit(const struct fuse_init_out* payload); - // SetUpFuseServer creates 2 pipes for communication and forks FUSE server. - void SetUpFuseServer(); + // A command switch that dispatch different FuseTestCmd to its handler. + void ServerHandleCommand(); - // GetPayloadSize is a helper function to get the number of bytes of a - // specific FUSE operation struct. - size_t GetPayloadSize(uint32_t opcode, bool in); + // The FUSE server side's corresponding code of `SetServerResponse()`. + // Handles `kSetResponse` command. Saves the fake response into its output + // memory queue. + void ServerReceiveResponse(); + + // The FUSE server side's corresponding code of `SetServerInodeLookup()`. + // Handles `kSetInodeLookup` command. Receives an expected file mode and + // file path under the mount point. + void ServerReceiveInodeLookup(); + + // The FUSE server side's corresponding code of `GetServerActualRequest()`. + // Handles `kGetRequest` command. Sends the next received request pointed by + // the cursor. + void ServerSendReceivedRequest(); + + // Sends a uint32_t data via socket. + inline void ServerSendData(uint32_t data); + + // The FUSE server side's corresponding code of `SkipServerActualRequest()`. + // Handles `kSkipRequest` command. Skip the request pointed by current cursor. + void ServerSkipReceivedRequest(); + + // Handles FUSE request sent to /dev/fuse by its saved responses. + void ServerProcessFuseRequest(); + + // Responds to FUSE request with a saved data. + void ServerRespondFuseSuccess(FuseMemBuffer& mem_buf, + const FuseMemBlock& block, uint64_t unique); + + // Responds an error header to /dev/fuse when bad thing happens. + void ServerRespondFuseError(uint64_t unique); int dev_fd_; - int set_expected_[2]; - int done_[2]; - - std::vector<char> buf_; - std::vector<char> mem_in_; - std::vector<char> mem_out_; - ssize_t len_in_; - ssize_t len_out_; + int sock_[2]; + + uint64_t nodeid_; + std::unordered_map<std::string, FuseMemBlock> lookup_map_; + + FuseMemBuffer requests_; + FuseMemBuffer responses_; + FuseMemBuffer lookups_; }; } // namespace testing diff --git a/test/fuse/linux/fuse_fd_util.cc b/test/fuse/linux/fuse_fd_util.cc new file mode 100644 index 000000000..30d1157bb --- /dev/null +++ b/test/fuse/linux/fuse_fd_util.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/fuse/linux/fuse_fd_util.h" + +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/uio.h> + +#include <string> +#include <vector> + +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/fuse_util.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<FileDescriptor> FuseFdTest::OpenPath(const std::string &path, + uint32_t flags, uint64_t fh) { + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload = { + .fh = fh, + .open_flags = flags, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + + auto res = Open(path.c_str(), flags); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; +} + +Cleanup FuseFdTest::CloseFD(FileDescriptor &fd) { + return Cleanup([&] { + close(fd.release()); + SkipServerActualRequest(); + }); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_fd_util.h b/test/fuse/linux/fuse_fd_util.h new file mode 100644 index 000000000..066185c94 --- /dev/null +++ b/test/fuse/linux/fuse_fd_util.h @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ +#define GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ + +#include <fcntl.h> +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> + +#include "test/fuse/linux/fuse_base.h" +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +class FuseFdTest : public FuseTest { + public: + // Sets the FUSE server to respond to a FUSE_OPEN with corresponding flags and + // fh. Then does a real file system open on the absolute path to get an fd. + PosixErrorOr<FileDescriptor> OpenPath(const std::string &path, + uint32_t flags = O_RDONLY, + uint64_t fh = 1); + + // Returns a cleanup object that closes the fd when it is destroyed. After + // the close is done, tells the FUSE server to skip this FUSE_RELEASE. + Cleanup CloseFD(FileDescriptor &fd); +}; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ diff --git a/test/fuse/linux/mkdir_test.cc b/test/fuse/linux/mkdir_test.cc new file mode 100644 index 000000000..9647cb93f --- /dev/null +++ b/test/fuse/linux/mkdir_test.cc @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class MkdirTest : public FuseTest { + protected: + const std::string test_dir_ = "test_dir"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(MkdirTest, CreateDir) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_); + const mode_t new_umask = 0077; + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKDIR, iov_out); + TempUmask mask(new_umask); + ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_mkdir_in in_payload; + std::vector<char> actual_dir(test_dir_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_dir); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + sizeof(in_payload) + test_dir_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_MKDIR); + EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask); + EXPECT_EQ(in_payload.umask, new_umask); + EXPECT_EQ(std::string(actual_dir.data()), test_dir_); +} + +TEST_F(MkdirTest, FileTypeError) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKDIR, iov_out); + ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/mknod_test.cc b/test/fuse/linux/mknod_test.cc new file mode 100644 index 000000000..74c74d76b --- /dev/null +++ b/test/fuse/linux/mknod_test.cc @@ -0,0 +1,107 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class MknodTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(MknodTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + const mode_t new_umask = 0077; + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + TempUmask mask(new_umask); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_mknod_in in_payload; + std::vector<char> actual_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + sizeof(in_payload) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_MKNOD); + EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask); + EXPECT_EQ(in_payload.umask, new_umask); + EXPECT_EQ(in_payload.rdev, 0); + EXPECT_EQ(std::string(actual_file.data()), test_file_); +} + +TEST_F(MknodTest, FileTypeError) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + // server return directory instead of regular file should cause an error. + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +TEST_F(MknodTest, NodeIDError) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = + DefaultEntryOut(S_IFREG | perms_, FUSE_ROOT_ID); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/open_test.cc b/test/fuse/linux/open_test.cc new file mode 100644 index 000000000..4b0c4a805 --- /dev/null +++ b/test/fuse/linux/open_test.cc @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class OpenTest : public FuseTest { + // OpenTest doesn't care the release request when close a fd, + // so doesn't check leftover requests when tearing down. + void TearDown() { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t regular_file_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + + struct fuse_open_out out_payload_ = { + .fh = 1, + .open_flags = O_RDWR, + }; +}; + +TEST_F(OpenTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, regular_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPEN, iov_out); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + + struct fuse_in_header in_header; + struct fuse_open_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_OPEN); + EXPECT_EQ(in_payload.flags, O_RDWR); + EXPECT_THAT(fcntl(fd.get(), F_GETFL), SyscallSucceedsWithValue(O_RDWR)); +} + +TEST_F(OpenTest, SetNoOpen) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, regular_file_); + + // ENOSYS indicates open is not implemented. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + .error = -ENOSYS, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPEN, iov_out); + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + SkipServerActualRequest(); + + // check open doesn't send new request. + uint32_t recieved_before = GetServerTotalReceivedBytes(); + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before); +} + +TEST_F(OpenTest, OpenFail) { + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + .error = -ENOENT, + }; + + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPENDIR, iov_out); + ASSERT_THAT(open(mount_point_.path().c_str(), O_RDWR), + SyscallFailsWithErrno(ENOENT)); + + struct fuse_in_header in_header; + struct fuse_open_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_OPENDIR); + EXPECT_EQ(in_payload.flags, O_RDWR); +} + +TEST_F(OpenTest, DirectoryFlagOnRegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + SetServerInodeLookup(test_file_, regular_file_); + ASSERT_THAT(open(test_file_path.c_str(), O_RDWR | O_DIRECTORY), + SyscallFailsWithErrno(ENOTDIR)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/read_test.cc b/test/fuse/linux/read_test.cc new file mode 100644 index 000000000..88fc299d8 --- /dev/null +++ b/test/fuse/linux/read_test.cc @@ -0,0 +1,390 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReadTest : public FuseTest { + void SetUp() override { + FuseTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + // TearDown overrides the parent's function + // to skip checking the unconsumed release request at the end. + void TearDown() override { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + const uint64_t test_fh_ = 1; + const uint32_t open_flag_ = O_RDWR; + + std::string test_file_path_; + + PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path, + uint64_t size = 512) { + SetServerInodeLookup(test_file_, test_file_mode_, size); + + struct fuse_out_header out_header_open = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload_open = { + .fh = test_fh_, + .open_flags = open_flag_, + }; + auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open); + SetServerResponse(FUSE_OPEN, iov_out_open); + + auto res = Open(path.c_str(), open_flag_); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; + } +}; + +class ReadTestSmallMaxRead : public ReadTest { + void SetUp() override { + MountFuse(mountOpts); + SetUpFuseServer(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + protected: + constexpr static char mountOpts[] = + "rootmode=755,user_id=0,group_id=0,max_read=4096"; + // 4096 is hard-coded as the max_read in mount options. + const int size_fragment = 4096; +}; + +TEST_F(ReadTest, ReadWhole) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the read. + const int n_read = 5; + std::vector<char> data(n_read); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + EXPECT_EQ(buf, data); +} + +TEST_F(ReadTest, ReadPartial) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the read. + const int n_data = 10; + std::vector<char> data(n_data); + RandomizeBuffer(data.data(), data.size()); + // Note: due to read ahead, current read implementation will treat any + // response that is longer than requested as correct (i.e. not reach the EOF). + // Therefore, the test below should make sure the size to read does not exceed + // n_data. + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + std::vector<char> buf(n_data); + + // Read 1 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 1), SyscallSucceedsWithValue(1)); + + // Check the 1-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + + // Read 3 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 3), SyscallSucceedsWithValue(3)); + + // Check the 3-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_payload_read.offset, 1); + + // Read 5 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 5), SyscallSucceedsWithValue(5)); + + // Check the 5-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_payload_read.offset, 4); +} + +TEST_F(ReadTest, PRead) { + const int file_size = 512; + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size)); + + // Prepare for the read. + const int n_read = 5; + std::vector<char> data(n_read); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read some bytes. + std::vector<char> buf(n_read); + const int offset_read = file_size >> 1; + EXPECT_THAT(pread(fd.get(), buf.data(), n_read, offset_read), + SyscallSucceedsWithValue(n_read)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, offset_read); + EXPECT_EQ(buf, data); +} + +TEST_F(ReadTest, ReadZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Issue the read. + std::vector<char> buf; + EXPECT_THAT(read(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(ReadTest, ReadShort) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the short read. + const int n_read = 5; + std::vector<char> data(n_read >> 1); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(data.size())); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + std::vector<char> short_buf(buf.begin(), buf.begin() + data.size()); + EXPECT_EQ(short_buf, data); +} + +TEST_F(ReadTest, ReadShortEOF) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the short read. + struct fuse_out_header out_header_read = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + const int n_read = 10; + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), SyscallSucceedsWithValue(0)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); +} + +TEST_F(ReadTestSmallMaxRead, ReadSmallMaxRead) { + const int n_fragment = 10; + const int n_read = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read)); + + // Prepare for the read. + std::vector<char> data(size_fragment); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + + for (int i = 0; i < n_fragment; ++i) { + SetServerResponse(FUSE_READ, iov_out_read); + } + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read)); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check each read segment. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, i * size_fragment); + EXPECT_EQ(in_payload_read.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + EXPECT_EQ(std::vector<char>(it, it + size_fragment), data); + } +} + +TEST_F(ReadTestSmallMaxRead, ReadSmallMaxReadShort) { + const int n_fragment = 10; + const int n_read = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read)); + + // Prepare for the read. + std::vector<char> data(size_fragment); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + + for (int i = 0; i < n_fragment - 1; ++i) { + SetServerResponse(FUSE_READ, iov_out_read); + } + + // The last fragment is a short read. + std::vector<char> half_data(data.begin(), data.begin() + (data.size() >> 1)); + struct fuse_out_header out_header_read_short = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header) + + half_data.size()), + }; + auto iov_out_read_short = + FuseGenerateIovecs(out_header_read_short, half_data); + SetServerResponse(FUSE_READ, iov_out_read_short); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read - (data.size() >> 1))); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check each read segment. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, i * size_fragment); + EXPECT_EQ(in_payload_read.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + if (i != n_fragment - 1) { + EXPECT_EQ(std::vector<char>(it, it + data.size()), data); + } else { + EXPECT_EQ(std::vector<char>(it, it + half_data.size()), half_data); + } + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/readdir_test.cc b/test/fuse/linux/readdir_test.cc new file mode 100644 index 000000000..2afb4b062 --- /dev/null +++ b/test/fuse/linux/readdir_test.cc @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <dirent.h> +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <linux/unistd.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +#define FUSE_NAME_OFFSET offsetof(struct fuse_dirent, name) +#define FUSE_DIRENT_ALIGN(x) \ + (((x) + sizeof(uint64_t) - 1) & ~(sizeof(uint64_t) - 1)) +#define FUSE_DIRENT_SIZE(d) FUSE_DIRENT_ALIGN(FUSE_NAME_OFFSET + (d)->namelen) + +namespace gvisor { +namespace testing { + +namespace { + +class ReaddirTest : public FuseTest { + public: + void fill_fuse_dirent(char *buf, const char *name, uint64_t ino) { + size_t namelen = strlen(name); + size_t entlen = FUSE_NAME_OFFSET + namelen; + size_t entlen_padded = FUSE_DIRENT_ALIGN(entlen); + struct fuse_dirent *dirent; + + dirent = reinterpret_cast<struct fuse_dirent *>(buf); + dirent->ino = ino; + dirent->namelen = namelen; + memcpy(dirent->name, name, namelen); + memset(dirent->name + namelen, 0, entlen_padded - entlen); + } + + protected: + const std::string test_dir_name_ = "test_dir"; +}; + +TEST_F(ReaddirTest, SingleEntry) { + const std::string test_dir_path = + JoinPath(mount_point_.path().c_str(), test_dir_name_); + + const uint64_t ino_dir = 1024; + // We need to make sure the test dir is a directory that can be found. + mode_t expected_mode = + S_IFDIR | S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; + struct fuse_attr dir_attr = { + .ino = ino_dir, + .size = 512, + .blocks = 4, + .mode = expected_mode, + .blksize = 4096, + }; + + // We need to make sure the test dir is a directory that can be found. + struct fuse_out_header lookup_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out lookup_payload = { + .nodeid = 1, + .entry_valid = true, + .attr_valid = true, + .attr = dir_attr, + }; + + struct fuse_out_header open_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out open_payload = { + .fh = 1, + }; + auto iov_out = FuseGenerateIovecs(lookup_header, lookup_payload); + SetServerResponse(FUSE_LOOKUP, iov_out); + + iov_out = FuseGenerateIovecs(open_header, open_payload); + SetServerResponse(FUSE_OPENDIR, iov_out); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_dir_path.c_str(), O_RDONLY)); + + // The open command makes two syscalls. Lookup the dir file and open. + // We don't need to inspect those headers in this test. + SkipServerActualRequest(); // LOOKUP. + SkipServerActualRequest(); // OPENDIR. + + // Readdir test code. + std::string dot = "."; + std::string dot_dot = ".."; + std::string test_file = "testFile"; + + // Figure out how many dirents to send over and allocate them appropriately. + // Each dirent has a dynamic name and a static metadata part. The dirent size + // is aligned to being a multiple of 8. + size_t dot_file_dirent_size = + FUSE_DIRENT_ALIGN(dot.length() + FUSE_NAME_OFFSET); + size_t dot_dot_file_dirent_size = + FUSE_DIRENT_ALIGN(dot_dot.length() + FUSE_NAME_OFFSET); + size_t test_file_dirent_size = + FUSE_DIRENT_ALIGN(test_file.length() + FUSE_NAME_OFFSET); + + // Create an appropriately sized payload. + size_t readdir_payload_size = + test_file_dirent_size + dot_file_dirent_size + dot_dot_file_dirent_size; + std::vector<char> readdir_payload_vec(readdir_payload_size); + char *readdir_payload = readdir_payload_vec.data(); + + // Use fake ino for other directories. + fill_fuse_dirent(readdir_payload, dot.c_str(), ino_dir - 2); + fill_fuse_dirent(readdir_payload + dot_file_dirent_size, dot_dot.c_str(), + ino_dir - 1); + fill_fuse_dirent( + readdir_payload + dot_file_dirent_size + dot_dot_file_dirent_size, + test_file.c_str(), ino_dir); + + struct fuse_out_header readdir_header = { + .len = uint32_t(sizeof(struct fuse_out_header) + readdir_payload_size), + }; + struct fuse_out_header readdir_header_break = { + .len = uint32_t(sizeof(struct fuse_out_header)), + }; + + iov_out = FuseGenerateIovecs(readdir_header, readdir_payload_vec); + SetServerResponse(FUSE_READDIR, iov_out); + + iov_out = FuseGenerateIovecs(readdir_header_break); + SetServerResponse(FUSE_READDIR, iov_out); + + std::vector<char> buf(4090, 0); + int nread, off = 0, i = 0; + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceeds()); + for (; off < nread;) { + struct dirent64 *ent = (struct dirent64 *)(buf.data() + off); + off += ent->d_reclen; + switch (i++) { + case 0: + EXPECT_EQ(std::string(ent->d_name), dot); + break; + case 1: + EXPECT_EQ(std::string(ent->d_name), dot_dot); + break; + case 2: + EXPECT_EQ(std::string(ent->d_name), test_file); + break; + } + } + + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(0)); + + SkipServerActualRequest(); // READDIR. + SkipServerActualRequest(); // READDIR with no data. + + // Clean up. + fd.reset(-1); + + struct fuse_in_header in_header; + struct fuse_release_in in_payload; + + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_RELEASEDIR); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/readlink_test.cc b/test/fuse/linux/readlink_test.cc new file mode 100644 index 000000000..2cba8fc23 --- /dev/null +++ b/test/fuse/linux/readlink_test.cc @@ -0,0 +1,85 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReadlinkTest : public FuseTest { + protected: + const std::string test_file_ = "test_file_"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(ReadlinkTest, ReadSymLink) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFLNK | perms_); + + struct fuse_out_header out_header = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)) + + static_cast<uint32_t>(test_file_.length()) + 1, + }; + std::string link = test_file_; + auto iov_out = FuseGenerateIovecs(out_header, link); + SetServerResponse(FUSE_READLINK, iov_out); + const std::string actual_link = + ASSERT_NO_ERRNO_AND_VALUE(ReadLink(symlink_path)); + + struct fuse_in_header in_header; + auto iov_in = FuseGenerateIovecs(in_header); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header)); + EXPECT_EQ(in_header.opcode, FUSE_READLINK); + EXPECT_EQ(0, memcmp(actual_link.c_str(), link.data(), link.size())); + + // next readlink should have link cached, so shouldn't have new request to + // server. + uint32_t recieved_before = GetServerTotalReceivedBytes(); + ASSERT_NO_ERRNO(ReadLink(symlink_path)); + EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before); +} + +TEST_F(ReadlinkTest, NotSymlink) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | perms_); + + std::vector<char> buf(PATH_MAX + 1); + ASSERT_THAT(readlink(test_file_path.c_str(), buf.data(), PATH_MAX), + SyscallFailsWithErrno(EINVAL)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/release_test.cc b/test/fuse/linux/release_test.cc new file mode 100644 index 000000000..b5adb0870 --- /dev/null +++ b/test/fuse/linux/release_test.cc @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReleaseTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; +}; + +TEST_F(ReleaseTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload = { + .fh = 1, + .open_flags = O_RDWR, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path, O_RDWR)); + SkipServerActualRequest(); + ASSERT_THAT(close(fd.release()), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_release_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_RELEASE); + EXPECT_EQ(in_payload.flags, O_RDWR); + EXPECT_EQ(in_payload.fh, 1); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/rmdir_test.cc b/test/fuse/linux/rmdir_test.cc new file mode 100644 index 000000000..e3200e446 --- /dev/null +++ b/test/fuse/linux/rmdir_test.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class RmDirTest : public FuseTest { + protected: + const std::string test_dir_name_ = "test_dir"; + const std::string test_subdir_ = "test_subdir"; + const mode_t test_dir_mode_ = S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(RmDirTest, NormalRmDir) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_name_); + + SetServerInodeLookup(test_dir_name_, test_dir_mode_); + + // RmDir code. + struct fuse_out_header rmdir_header = { + .len = sizeof(struct fuse_out_header), + }; + + auto iov_out = FuseGenerateIovecs(rmdir_header); + SetServerResponse(FUSE_RMDIR, iov_out); + + ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_dirname(test_dir_name_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, actual_dirname); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_RMDIR); + EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_); +} + +TEST_F(RmDirTest, NormalRmDirSubdir) { + SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO); + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_subdir_, test_dir_name_); + SetServerInodeLookup(test_dir_name_, test_dir_mode_); + + // RmDir code. + struct fuse_out_header rmdir_header = { + .len = sizeof(struct fuse_out_header), + }; + + auto iov_out = FuseGenerateIovecs(rmdir_header); + SetServerResponse(FUSE_RMDIR, iov_out); + + ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_dirname(test_dir_name_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, actual_dirname); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_RMDIR); + EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/setstat_test.cc b/test/fuse/linux/setstat_test.cc new file mode 100644 index 000000000..68301c775 --- /dev/null +++ b/test/fuse/linux/setstat_test.cc @@ -0,0 +1,338 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> +#include <utime.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_fd_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class SetStatTest : public FuseFdTest { + public: + void SetUp() override { + FuseFdTest::SetUp(); + test_dir_path_ = JoinPath(mount_point_.path(), test_dir_); + test_file_path_ = JoinPath(mount_point_.path(), test_file_); + } + + protected: + const uint64_t fh = 23; + const std::string test_dir_ = "testdir"; + const std::string test_file_ = "testfile"; + const mode_t test_dir_mode_ = S_IFDIR | S_IRUSR | S_IWUSR | S_IXUSR; + const mode_t test_file_mode_ = S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR; + + std::string test_dir_path_; + std::string test_file_path_; +}; + +TEST_F(SetStatTest, ChmodDir) { + // Set up fixture. + SetServerInodeLookup(test_dir_, test_dir_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + mode_t set_mode = S_IRGRP | S_IWGRP | S_IXGRP; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(set_mode, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(chmod(test_dir_path_.c_str(), set_mode), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_MODE); + EXPECT_EQ(in_payload.mode, S_IFDIR | set_mode); +} + +TEST_F(SetStatTest, ChownDir) { + // Set up fixture. + SetServerInodeLookup(test_dir_, test_dir_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_dir_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(chown(test_dir_path_.c_str(), 1025, 1025), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID); + EXPECT_EQ(in_payload.uid, 1025); + EXPECT_EQ(in_payload.gid, 1025); +} + +TEST_F(SetStatTest, TruncateFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(truncate(test_file_path_.c_str(), 321), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_SIZE); + EXPECT_EQ(in_payload.size, 321); +} + +TEST_F(SetStatTest, UtimeFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + time_t expected_atime = 1597159766, expected_mtime = 1597159765; + struct utimbuf times = { + .actime = expected_atime, + .modtime = expected_mtime, + }; + EXPECT_THAT(utime(test_file_path_.c_str(), ×), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME); + EXPECT_EQ(in_payload.atime, expected_atime); + EXPECT_EQ(in_payload.mtime, expected_mtime); +} + +TEST_F(SetStatTest, UtimesFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_file_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + struct timeval expected_times[2] = { + { + .tv_sec = 1597159766, + .tv_usec = 234945, + }, + { + .tv_sec = 1597159765, + .tv_usec = 232341, + }, + }; + EXPECT_THAT(utimes(test_file_path_.c_str(), expected_times), + SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME); + EXPECT_EQ(in_payload.atime, expected_times[0].tv_sec); + EXPECT_EQ(in_payload.atimensec, expected_times[0].tv_usec * 1000); + EXPECT_EQ(in_payload.mtime, expected_times[1].tv_sec); + EXPECT_EQ(in_payload.mtimensec, expected_times[1].tv_usec * 1000); +} + +TEST_F(SetStatTest, FtruncateFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_file_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(ftruncate(fd.get(), 321), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_SIZE | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.size, 321); +} + +TEST_F(SetStatTest, FchmodFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + mode_t set_mode = S_IROTH | S_IWOTH | S_IXOTH; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(set_mode, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(fchmod(fd.get(), set_mode), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_MODE | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.mode, S_IFREG | set_mode); +} + +TEST_F(SetStatTest, FchownFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(fchown(fd.get(), 1025, 1025), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.uid, 1025); + EXPECT_EQ(in_payload.gid, 1025); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/stat_test.cc b/test/fuse/linux/stat_test.cc index 172e09867..6f032cac1 100644 --- a/test/fuse/linux/stat_test.cc +++ b/test/fuse/linux/stat_test.cc @@ -18,12 +18,16 @@ #include <sys/stat.h> #include <sys/statfs.h> #include <sys/types.h> +#include <sys/uio.h> #include <unistd.h> #include <vector> #include "gtest/gtest.h" -#include "test/fuse/linux/fuse_base.h" +#include "test/fuse/linux/fuse_fd_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" #include "test/util/test_util.h" namespace gvisor { @@ -31,89 +35,45 @@ namespace testing { namespace { -class StatTest : public FuseTest { +class StatTest : public FuseFdTest { public: - bool CompareRequest(void* expected_mem, size_t expected_len, void* real_mem, - size_t real_len) override { - if (expected_len != real_len) return false; - struct fuse_in_header* real_header = - reinterpret_cast<fuse_in_header*>(real_mem); - - if (real_header->opcode != FUSE_GETATTR) { - std::cerr << "expect header opcode " << FUSE_GETATTR << " but got " - << real_header->opcode << std::endl; - return false; - } - return true; + void SetUp() override { + FuseFdTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path(), test_file_); } + protected: bool StatsAreEqual(struct stat expected, struct stat actual) { - // device number will be dynamically allocated by kernel, we cannot know - // in advance + // Device number will be dynamically allocated by kernel, we cannot know in + // advance. actual.st_dev = expected.st_dev; return memcmp(&expected, &actual, sizeof(struct stat)) == 0; } + + const std::string test_file_ = "testfile"; + const mode_t expected_mode = S_IFREG | S_IRUSR | S_IWUSR; + const uint64_t fh = 23; + + std::string test_file_path_; }; TEST_F(StatTest, StatNormal) { - struct iovec iov_in[2]; - struct iovec iov_out[2]; - - struct fuse_in_header in_header = { - .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in), - .opcode = FUSE_GETATTR, - .unique = 4, - .nodeid = 1, - .uid = 0, - .gid = 0, - .pid = 4, - .padding = 0, - }; - struct fuse_getattr_in in_payload = {0}; - iov_in[0].iov_len = sizeof(in_header); - iov_in[0].iov_base = &in_header; - iov_in[1].iov_len = sizeof(in_payload); - iov_in[1].iov_base = &in_payload; - - mode_t expected_mode = S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; - struct timespec atime = {.tv_sec = 1595436289, .tv_nsec = 134150844}; - struct timespec mtime = {.tv_sec = 1595436290, .tv_nsec = 134150845}; - struct timespec ctime = {.tv_sec = 1595436291, .tv_nsec = 134150846}; + // Set up fixture. + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 1); struct fuse_out_header out_header = { .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), - .error = 0, - .unique = 4, - }; - struct fuse_attr attr = { - .ino = 1, - .size = 512, - .blocks = 4, - .atime = static_cast<uint64_t>(atime.tv_sec), - .mtime = static_cast<uint64_t>(mtime.tv_sec), - .ctime = static_cast<uint64_t>(ctime.tv_sec), - .atimensec = static_cast<uint32_t>(atime.tv_nsec), - .mtimensec = static_cast<uint32_t>(mtime.tv_nsec), - .ctimensec = static_cast<uint32_t>(ctime.tv_nsec), - .mode = expected_mode, - .nlink = 2, - .uid = 1234, - .gid = 4321, - .rdev = 12, - .blksize = 4096, }; struct fuse_attr_out out_payload = { .attr = attr, }; - iov_out[0].iov_len = sizeof(out_header); - iov_out[0].iov_base = &out_header; - iov_out[1].iov_len = sizeof(out_payload); - iov_out[1].iov_base = &out_payload; - - SetExpected(iov_in, 2, iov_out, 2); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + // Make syscall. struct stat stat_buf; EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), SyscallSucceeds()); + // Check filesystem operation result. struct stat expected_stat = { .st_ino = attr.ino, .st_nlink = attr.nlink, @@ -124,43 +84,133 @@ TEST_F(StatTest, StatNormal) { .st_size = static_cast<off_t>(attr.size), .st_blksize = attr.blksize, .st_blocks = static_cast<blkcnt_t>(attr.blocks), - .st_atim = atime, - .st_mtim = mtime, - .st_ctim = ctime, + .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime), + .tv_nsec = attr.atimensec}, + .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime), + .tv_nsec = attr.mtimensec}, + .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime), + .tv_nsec = attr.ctimensec}, }; EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat)); - WaitCompleted(); -} -TEST_F(StatTest, StatNotFound) { - struct iovec iov_in[2]; - struct iovec iov_out[2]; + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); - struct fuse_in_header in_header = { - .len = sizeof(struct fuse_in_header) + sizeof(struct fuse_getattr_in), - .opcode = FUSE_GETATTR, - .unique = 4, - }; - struct fuse_getattr_in in_payload = {0}; - iov_in[0].iov_len = sizeof(in_header); - iov_in[0].iov_base = &in_header; - iov_in[1].iov_len = sizeof(in_payload); - iov_in[1].iov_base = &in_payload; + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} +TEST_F(StatTest, StatNotFound) { + // Set up fixture. struct fuse_out_header out_header = { .len = sizeof(struct fuse_out_header), .error = -ENOENT, - .unique = 4, }; - iov_out[0].iov_len = sizeof(out_header); - iov_out[0].iov_base = &out_header; - - SetExpected(iov_in, 2, iov_out, 1); + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_GETATTR, iov_out); + // Make syscall. struct stat stat_buf; EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), SyscallFailsWithErrno(ENOENT)); - WaitCompleted(); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, FstatNormal) { + // Set up fixture. + SetServerInodeLookup(test_file_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + struct stat stat_buf; + EXPECT_THAT(fstat(fd.get(), &stat_buf), SyscallSucceeds()); + + // Check filesystem operation result. + struct stat expected_stat = { + .st_ino = attr.ino, + .st_nlink = attr.nlink, + .st_mode = expected_mode, + .st_uid = attr.uid, + .st_gid = attr.gid, + .st_rdev = attr.rdev, + .st_size = static_cast<off_t>(attr.size), + .st_blksize = attr.blksize, + .st_blocks = static_cast<blkcnt_t>(attr.blocks), + .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime), + .tv_nsec = attr.atimensec}, + .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime), + .tv_nsec = attr.mtimensec}, + .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime), + .tv_nsec = attr.ctimensec}, + }; + EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat)); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, StatByFileHandle) { + // Set up fixture. + SetServerInodeLookup(test_file_, expected_mode, 0); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2, 0); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + std::vector<char> buf(1); + // Since this is an empty file, it won't issue FUSE_READ. But a FUSE_GETATTR + // will be issued before read completes. + EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, FUSE_GETATTR_FH); + EXPECT_EQ(in_payload.fh, fh); } } // namespace diff --git a/test/fuse/linux/symlink_test.cc b/test/fuse/linux/symlink_test.cc new file mode 100644 index 000000000..2c3a52987 --- /dev/null +++ b/test/fuse/linux/symlink_test.cc @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class SymlinkTest : public FuseTest { + protected: + const std::string target_file_ = "target_file_"; + const std::string symlink_ = "symlink_"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(SymlinkTest, CreateSymLink) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), symlink_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFLNK | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SYMLINK, iov_out); + ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()), + SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_target_file(target_file_.length() + 1); + std::vector<char> actual_symlink(symlink_.length() + 1); + auto iov_in = + FuseGenerateIovecs(in_header, actual_symlink, actual_target_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + symlink_.length() + target_file_.length() + 2); + EXPECT_EQ(in_header.opcode, FUSE_SYMLINK); + EXPECT_EQ(std::string(actual_target_file.data()), target_file_); + EXPECT_EQ(std::string(actual_symlink.data()), symlink_); +} + +TEST_F(SymlinkTest, FileTypeError) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), symlink_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SYMLINK, iov_out); + ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/unlink_test.cc b/test/fuse/linux/unlink_test.cc new file mode 100644 index 000000000..13efbf7c7 --- /dev/null +++ b/test/fuse/linux/unlink_test.cc @@ -0,0 +1,107 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class UnlinkTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; + const std::string test_subdir_ = "test_subdir"; +}; + +TEST_F(UnlinkTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds()); + struct fuse_in_header in_header; + std::vector<char> unlinked_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, unlinked_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_UNLINK); + EXPECT_EQ(std::string(unlinked_file.data()), test_file_); +} + +TEST_F(UnlinkTest, RegularFileSubDir) { + SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO); + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_subdir_, test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds()); + struct fuse_in_header in_header; + std::vector<char> unlinked_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, unlinked_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_UNLINK); + EXPECT_EQ(std::string(unlinked_file.data()), test_file_); +} + +TEST_F(UnlinkTest, NoFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallFailsWithErrno(ENOENT)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/write_test.cc b/test/fuse/linux/write_test.cc new file mode 100644 index 000000000..1a62beb96 --- /dev/null +++ b/test/fuse/linux/write_test.cc @@ -0,0 +1,303 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class WriteTest : public FuseTest { + void SetUp() override { + FuseTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + // TearDown overrides the parent's function + // to skip checking the unconsumed release request at the end. + void TearDown() override { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + const uint64_t test_fh_ = 1; + const uint32_t open_flag_ = O_RDWR; + + std::string test_file_path_; + + PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path, + uint64_t size = 512) { + SetServerInodeLookup(test_file_, test_file_mode_, size); + + struct fuse_out_header out_header_open = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload_open = { + .fh = test_fh_, + .open_flags = open_flag_, + }; + auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open); + SetServerResponse(FUSE_OPEN, iov_out_open); + + auto res = Open(path.c_str(), open_flag_); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; + } +}; + +class WriteTestSmallMaxWrite : public WriteTest { + void SetUp() override { + MountFuse(); + SetUpFuseServer(&fuse_init_payload); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + protected: + const static uint32_t max_write_ = 4096; + constexpr static struct fuse_init_out fuse_init_payload = { + .major = 7, + .max_write = max_write_, + }; + + const uint32_t size_fragment = max_write_; +}; + +TEST_F(WriteTest, WriteNormal) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_write, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_write)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteShort) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10, n_written = 5; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_written, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_written)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteShortZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = 0, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), SyscallFailsWithErrno(EIO)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Issue the write. + std::vector<char> buf(0); + EXPECT_THAT(write(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(WriteTest, PWrite) { + const int file_size = 512; + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_write, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + const int offset_write = file_size >> 1; + EXPECT_THAT(pwrite(fd.get(), buf.data(), n_write, offset_write), + SyscallSucceedsWithValue(n_write)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, offset_write); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTestSmallMaxWrite, WriteSmallMaxWrie) { + const int n_fragment = 10; + const int n_write = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_write)); + + // Prepare for the write. + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = size_fragment, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + + for (int i = 0; i < n_fragment; ++i) { + SetServerResponse(FUSE_WRITE, iov_out_write); + } + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_write)); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(size_fragment); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, i * size_fragment); + EXPECT_EQ(in_payload_write.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + EXPECT_EQ(std::vector<char>(it, it + size_fragment), payload_buf); + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/image/image_test.go b/test/image/image_test.go index ac6186688..968e62f63 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -63,8 +63,8 @@ func TestHelloWorld(t *testing.T) { } } -func runHTTPRequest(port int) error { - url := fmt.Sprintf("http://localhost:%d/not-found", port) +func runHTTPRequest(ip string, port int) error { + url := fmt.Sprintf("http://%s:%d/not-found", ip, port) resp, err := http.Get(url) if err != nil { return fmt.Errorf("error reaching http server: %v", err) @@ -73,7 +73,7 @@ func runHTTPRequest(port int) error { return fmt.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) } - url = fmt.Sprintf("http://localhost:%d/latin10k.txt", port) + url = fmt.Sprintf("http://%s:%d/latin10k.txt", ip, port) resp, err = http.Get(url) if err != nil { return fmt.Errorf("Error reaching http server: %v", err) @@ -95,13 +95,13 @@ func runHTTPRequest(port int) error { return nil } -func testHTTPServer(t *testing.T, port int) { +func testHTTPServer(t *testing.T, ip string, port int) { const requests = 10 ch := make(chan error, requests) for i := 0; i < requests; i++ { go func() { start := time.Now() - err := runHTTPRequest(port) + err := runHTTPRequest(ip, port) log.Printf("Response time %v: %v", time.Since(start).String(), err) ch <- err }() @@ -110,7 +110,7 @@ func testHTTPServer(t *testing.T, port int) { for i := 0; i < requests; i++ { err := <-ch if err != nil { - t.Errorf("testHTTPServer(%d) failed: %v", port, err) + t.Errorf("testHTTPServer(%s, %d) failed: %v", ip, port, err) } } } @@ -121,27 +121,28 @@ func TestHttpd(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 80 opts := dockerutil.RunOpts{ Image: "basic/httpd", - Ports: []int{80}, + Ports: []int{port}, } d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt") if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 80 is mapped to. - port, err := d.FindPort(ctx, 80) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } - testHTTPServer(t, port) + testHTTPServer(t, ip.String(), port) } func TestNginx(t *testing.T) { @@ -150,27 +151,28 @@ func TestNginx(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 80 opts := dockerutil.RunOpts{ Image: "basic/nginx", - Ports: []int{80}, + Ports: []int{port}, } d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt") if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 80 is mapped to. - port, err := d.FindPort(ctx, 80) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } - testHTTPServer(t, port) + testHTTPServer(t, ip.String(), port) } func TestMysql(t *testing.T) { @@ -218,26 +220,27 @@ func TestTomcat(t *testing.T) { defer d.CleanUp(ctx) // Start the server. + port := 8080 if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/tomcat", - Ports: []int{8080}, + Ports: []int{port}, }); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) + url := fmt.Sprintf("http://%s:%d", ip.String(), port) resp, err := http.Get(url) if err != nil { t.Errorf("Error reaching http server: %v", err) @@ -253,28 +256,29 @@ func TestRuby(t *testing.T) { defer d.CleanUp(ctx) // Execute the ruby workload. + port := 8080 opts := dockerutil.RunOpts{ Image: "basic/ruby", - Ports: []int{8080}, + Ports: []int{port}, } d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh") if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running, 'gem install' can take some time. - if err := testutil.WaitForHTTP(port, time.Minute); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, time.Minute); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) + url := fmt.Sprintf("http://%s:%d", ip.String(), port) resp, err := http.Get(url) if err != nil { t.Errorf("error reaching http server: %v", err) diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index e2beb30d5..398f70ecd 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -48,6 +48,13 @@ func singleTest(t *testing.T, test TestCase) { } } +// TODO(gvisor.dev/issue/3549): IPv6 NAT support. +func ipv4Test(t *testing.T, test TestCase) { + t.Run("IPv4", func(t *testing.T) { + iptablesTest(t, test, false) + }) +} + func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { if _, ok := Tests[test.Name()]; !ok { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) @@ -72,11 +79,6 @@ func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { d.CleanUp(context.Background()) }() - // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests. - if ipv6 && dockerutil.Runtime() != "runc" { - t.Skip("gVisor ip6tables not yet implemented") - } - // Create and start the container. opts := dockerutil.RunOpts{ Image: "iptables", @@ -314,75 +316,75 @@ func TestInputInvertDestination(t *testing.T) { singleTest(t, FilterInputInvertDestination{}) } -func TestOutputDestination(t *testing.T) { +func TestFilterOutputDestination(t *testing.T) { singleTest(t, FilterOutputDestination{}) } -func TestOutputInvertDestination(t *testing.T) { +func TestFilterOutputInvertDestination(t *testing.T) { singleTest(t, FilterOutputInvertDestination{}) } func TestNATPreRedirectUDPPort(t *testing.T) { - singleTest(t, NATPreRedirectUDPPort{}) + ipv4Test(t, NATPreRedirectUDPPort{}) } func TestNATPreRedirectTCPPort(t *testing.T) { - singleTest(t, NATPreRedirectTCPPort{}) + ipv4Test(t, NATPreRedirectTCPPort{}) } func TestNATPreRedirectTCPOutgoing(t *testing.T) { - singleTest(t, NATPreRedirectTCPOutgoing{}) + ipv4Test(t, NATPreRedirectTCPOutgoing{}) } func TestNATOutRedirectTCPIncoming(t *testing.T) { - singleTest(t, NATOutRedirectTCPIncoming{}) + ipv4Test(t, NATOutRedirectTCPIncoming{}) } func TestNATOutRedirectUDPPort(t *testing.T) { - singleTest(t, NATOutRedirectUDPPort{}) + ipv4Test(t, NATOutRedirectUDPPort{}) } func TestNATOutRedirectTCPPort(t *testing.T) { - singleTest(t, NATOutRedirectTCPPort{}) + ipv4Test(t, NATOutRedirectTCPPort{}) } func TestNATDropUDP(t *testing.T) { - singleTest(t, NATDropUDP{}) + ipv4Test(t, NATDropUDP{}) } func TestNATAcceptAll(t *testing.T) { - singleTest(t, NATAcceptAll{}) + ipv4Test(t, NATAcceptAll{}) } func TestNATOutRedirectIP(t *testing.T) { - singleTest(t, NATOutRedirectIP{}) + ipv4Test(t, NATOutRedirectIP{}) } func TestNATOutDontRedirectIP(t *testing.T) { - singleTest(t, NATOutDontRedirectIP{}) + ipv4Test(t, NATOutDontRedirectIP{}) } func TestNATOutRedirectInvert(t *testing.T) { - singleTest(t, NATOutRedirectInvert{}) + ipv4Test(t, NATOutRedirectInvert{}) } func TestNATPreRedirectIP(t *testing.T) { - singleTest(t, NATPreRedirectIP{}) + ipv4Test(t, NATPreRedirectIP{}) } func TestNATPreDontRedirectIP(t *testing.T) { - singleTest(t, NATPreDontRedirectIP{}) + ipv4Test(t, NATPreDontRedirectIP{}) } func TestNATPreRedirectInvert(t *testing.T) { - singleTest(t, NATPreRedirectInvert{}) + ipv4Test(t, NATPreRedirectInvert{}) } func TestNATRedirectRequiresProtocol(t *testing.T) { - singleTest(t, NATRedirectRequiresProtocol{}) + ipv4Test(t, NATRedirectRequiresProtocol{}) } func TestNATLoopbackSkipsPrerouting(t *testing.T) { - singleTest(t, NATLoopbackSkipsPrerouting{}) + ipv4Test(t, NATLoopbackSkipsPrerouting{}) } func TestInputSource(t *testing.T) { @@ -419,9 +421,9 @@ func TestFilterAddrs(t *testing.T) { } func TestNATPreOriginalDst(t *testing.T) { - singleTest(t, NATPreOriginalDst{}) + ipv4Test(t, NATPreOriginalDst{}) } func TestNATOutOriginalDst(t *testing.T) { - singleTest(t, NATOutOriginalDst{}) + ipv4Test(t, NATOutOriginalDst{}) } diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD index 3ce63c2c6..ccf1c735f 100644 --- a/test/packetimpact/dut/BUILD +++ b/test/packetimpact/dut/BUILD @@ -16,3 +16,13 @@ cc_binary( "//test/packetimpact/proto:posix_server_cc_proto", ], ) + +cc_binary( + name = "posix_server_dynamic", + srcs = ["posix_server.cc"], + deps = [ + grpcpp, + "//test/packetimpact/proto:posix_server_cc_grpc_proto", + "//test/packetimpact/proto:posix_server_cc_proto", + ], +) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index 29d4cc6fe..4de8540f6 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -21,6 +21,7 @@ #include <string.h> #include <sys/socket.h> #include <sys/types.h> +#include <time.h> #include <unistd.h> #include <iostream> @@ -28,6 +29,7 @@ #include "include/grpcpp/security/server_credentials.h" #include "include/grpcpp/server_builder.h" +#include "include/grpcpp/server_context.h" #include "test/packetimpact/proto/posix_server.grpc.pb.h" #include "test/packetimpact/proto/posix_server.pb.h" @@ -108,18 +110,20 @@ } class PosixImpl final : public posix_server::Posix::Service { - ::grpc::Status Accept(grpc_impl::ServerContext *context, + ::grpc::Status Accept(grpc::ServerContext *context, const ::posix_server::AcceptRequest *request, ::posix_server::AcceptResponse *response) override { sockaddr_storage addr; socklen_t addrlen = sizeof(addr); response->set_fd(accept(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); - response->set_errno_(errno); + if (response->fd() < 0) { + response->set_errno_(errno); + } return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); } - ::grpc::Status Bind(grpc_impl::ServerContext *context, + ::grpc::Status Bind(grpc::ServerContext *context, const ::posix_server::BindRequest *request, ::posix_server::BindResponse *response) override { if (!request->has_addr()) { @@ -136,19 +140,23 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret( bind(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Close(grpc_impl::ServerContext *context, + ::grpc::Status Close(grpc::ServerContext *context, const ::posix_server::CloseRequest *request, ::posix_server::CloseResponse *response) override { response->set_ret(close(request->fd())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Connect(grpc_impl::ServerContext *context, + ::grpc::Status Connect(grpc::ServerContext *context, const ::posix_server::ConnectRequest *request, ::posix_server::ConnectResponse *response) override { if (!request->has_addr()) { @@ -164,32 +172,38 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret(connect(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Fcntl(grpc_impl::ServerContext *context, + ::grpc::Status Fcntl(grpc::ServerContext *context, const ::posix_server::FcntlRequest *request, ::posix_server::FcntlResponse *response) override { response->set_ret(::fcntl(request->fd(), request->cmd(), request->arg())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } ::grpc::Status GetSockName( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::GetSockNameRequest *request, ::posix_server::GetSockNameResponse *response) override { sockaddr_storage addr; socklen_t addrlen = sizeof(addr); response->set_ret(getsockname( request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); } ::grpc::Status GetSockOpt( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::GetSockOptRequest *request, ::posix_server::GetSockOptResponse *response) override { switch (request->type()) { @@ -226,15 +240,19 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown SockOpt Type"); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Listen(grpc_impl::ServerContext *context, + ::grpc::Status Listen(grpc::ServerContext *context, const ::posix_server::ListenRequest *request, ::posix_server::ListenResponse *response) override { response->set_ret(listen(request->sockfd(), request->backlog())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -243,7 +261,9 @@ class PosixImpl final : public posix_server::Posix::Service { ::posix_server::SendResponse *response) override { response->set_ret(::send(request->sockfd(), request->buf().data(), request->buf().size(), request->flags())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -264,12 +284,14 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret(::sendto(request->sockfd(), request->buf().data(), request->buf().size(), request->flags(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } ::grpc::Status SetSockOpt( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::SetSockOptRequest *request, ::posix_server::SetSockOptResponse *response) override { switch (request->optval().val_case()) { @@ -286,9 +308,9 @@ class PosixImpl final : public posix_server::Posix::Service { break; } case ::posix_server::SockOptVal::kTimeval: { - timeval tv = {.tv_sec = static_cast<__time_t>( + timeval tv = {.tv_sec = static_cast<time_t>( request->optval().timeval().seconds()), - .tv_usec = static_cast<__suseconds_t>( + .tv_usec = static_cast<suseconds_t>( request->optval().timeval().microseconds())}; response->set_ret(setsockopt(request->sockfd(), request->level(), request->optname(), &tv, sizeof(tv))); @@ -298,16 +320,29 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown SockOpt Type"); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Socket(grpc_impl::ServerContext *context, + ::grpc::Status Socket(grpc::ServerContext *context, const ::posix_server::SocketRequest *request, ::posix_server::SocketResponse *response) override { response->set_fd( socket(request->domain(), request->type(), request->protocol())); - response->set_errno_(errno); + if (response->fd() < 0) { + response->set_errno_(errno); + } + return ::grpc::Status::OK; + } + + ::grpc::Status Shutdown(grpc::ServerContext *context, + const ::posix_server::ShutdownRequest *request, + ::posix_server::ShutdownResponse *response) override { + if (shutdown(request->fd(), request->how()) < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -320,7 +355,9 @@ class PosixImpl final : public posix_server::Posix::Service { if (response->ret() >= 0) { response->set_buf(buf.data(), response->ret()); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } }; diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index ccd20b10d..f32ed54ef 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -188,6 +188,15 @@ message SocketResponse { int32 errno_ = 2; // "errno" may fail to compile in c++. } +message ShutdownRequest { + int32 fd = 1; + int32 how = 2; +} + +message ShutdownResponse { + int32 errno_ = 1; // "errno" may fail to compile in c++. +} + message RecvRequest { int32 sockfd = 1; int32 len = 2; @@ -225,6 +234,8 @@ service Posix { rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse); // Call socket() on the DUT. rpc Socket(SocketRequest) returns (SocketResponse); + // Call shutdown() on the DUT. + rpc Shutdown(ShutdownRequest) returns (ShutdownResponse); // Call recv() on the DUT. rpc Recv(RecvRequest) returns (RecvResponse); } diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD index ff2be9b30..605dd4972 100644 --- a/test/packetimpact/runner/BUILD +++ b/test/packetimpact/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "bzl_library", "go_test") +load("//tools:defs.bzl", "bzl_library", "go_library", "go_test") package( default_visibility = ["//test/packetimpact:__subpackages__"], @@ -7,21 +7,31 @@ package( go_test( name = "packetimpact_test", - srcs = ["packetimpact_test.go"], + srcs = [ + "packetimpact_test.go", + ], tags = [ # Not intended to be run directly. "local", "manual", ], - deps = [ - "//pkg/test/dockerutil", - "//test/packetimpact/netdevs", - "@com_github_docker_docker//api/types/mount:go_default_library", - ], + deps = [":runner"], ) bzl_library( name = "defs_bzl", srcs = ["defs.bzl"], - visibility = ["//visibility:private"], + visibility = ["//test/packetimpact:__subpackages__"], +) + +go_library( + name = "runner", + testonly = True, + srcs = ["dut.go"], + visibility = ["//test/packetimpact:__subpackages__"], + deps = [ + "//pkg/test/dockerutil", + "//test/packetimpact/netdevs", + "@com_github_docker_docker//api/types/mount:go_default_library", + ], ) diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 93a36c6c2..f56d3c42e 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -23,8 +23,9 @@ def _packetimpact_test_impl(ctx): transitive_files = [] if hasattr(ctx.attr._test_runner, "data_runfiles"): transitive_files.append(ctx.attr._test_runner.data_runfiles.files) + files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server runfiles = ctx.runfiles( - files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary, + files = files, transitive_files = depset(transitive = transitive_files), collect_default = True, collect_data = True, @@ -38,7 +39,7 @@ _packetimpact_test = rule( cfg = "target", default = ":packetimpact_test", ), - "_posix_server_binary": attr.label( + "_posix_server": attr.label( cfg = "target", default = "//test/packetimpact/dut:posix_server", ), @@ -125,6 +126,7 @@ def packetimpact_go_test(name, size = "small", pure = True, expect_native_failur name = testbench_binary, size = size, pure = pure, + nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems. tags = [ "local", "manual", diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go new file mode 100644 index 000000000..96a0fb6c8 --- /dev/null +++ b/test/packetimpact/runner/dut.go @@ -0,0 +1,442 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package runner starts docker containers and networking for a packetimpact test. +package runner + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/packetimpact/netdevs" +) + +// stringList implements flag.Value. +type stringList []string + +// String implements flag.Value.String. +func (l *stringList) String() string { + return strings.Join(*l, ",") +} + +// Set implements flag.Value.Set. +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +var ( + native = false + testbenchBinary = "" + tshark = false + extraTestArgs = stringList{} + expectFailure = false + + // DutAddr is the IP addres for DUT. + DutAddr = net.IPv4(0, 0, 0, 10) + testbenchAddr = net.IPv4(0, 0, 0, 20) +) + +// RegisterFlags defines flags and associates them with the package-level +// exported variables above. It should be called by tests in their init +// functions. +func RegisterFlags(fs *flag.FlagSet) { + fs.BoolVar(&native, "native", false, "whether the test should be run natively") + fs.StringVar(&testbenchBinary, "testbench_binary", "", "path to the testbench binary") + fs.BoolVar(&tshark, "tshark", false, "use more verbose tshark in logs instead of tcpdump") + flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") + flag.BoolVar(&expectFailure, "expect_failure", false, "expect that the test will fail when run") +} + +// CtrlPort is the port that posix_server listens on. +const CtrlPort = "40000" + +// logger implements testutil.Logger. +// +// Labels logs based on their source and formats multi-line logs. +type logger string + +// Name implements testutil.Logger.Name. +func (l logger) Name() string { + return string(l) +} + +// Logf implements testutil.Logger.Logf. +func (l logger) Logf(format string, args ...interface{}) { + lines := strings.Split(fmt.Sprintf(format, args...), "\n") + log.Printf("%s: %s", l, lines[0]) + for _, line := range lines[1:] { + log.Printf("%*s %s", len(l), "", line) + } +} + +// TestWithDUT runs a packetimpact test with the given information. +func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT, containerAddr net.IP) { + if testbenchBinary == "" { + t.Fatal("--testbench_binary is missing") + } + dockerutil.EnsureSupportedDockerVersion() + + // Create the networks needed for the test. One control network is needed for + // the gRPC control packets and one test network on which to transmit the test + // packets. + ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) + testNet := dockerutil.NewNetwork(ctx, logger("testNet")) + for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { + for { + if err := createDockerNetwork(ctx, dn); err != nil { + t.Log("creating docker network:", err) + const wait = 100 * time.Millisecond + t.Logf("sleeping %s and will try creating docker network again", wait) + // This can fail if another docker network claimed the same IP so we'll + // just try again. + time.Sleep(wait) + continue + } + break + } + dn := dn + t.Cleanup(func() { + if err := dn.Cleanup(ctx); err != nil { + t.Errorf("unable to cleanup container %s: %s", dn.Name, err) + } + }) + // Sanity check. + if inspect, err := dn.Inspect(ctx); err != nil { + t.Fatalf("failed to inspect network %s: %v", dn.Name, err) + } else if inspect.Name != dn.Name { + t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) + } + } + + tmpDir, err := ioutil.TempDir("", "container-output") + if err != nil { + t.Fatal("creating temp dir:", err) + } + t.Cleanup(func() { + if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { + t.Errorf("unable to copy container output files: %s", err) + } + if err := os.RemoveAll(tmpDir); err != nil { + t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err) + } + }) + + const testOutputDir = "/tmp/testoutput" + + // Create the Docker container for the DUT. + var dut *dockerutil.Container + if native { + dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) + } else { + dut = dockerutil.MakeContainer(ctx, logger("dut")) + } + t.Cleanup(func() { + dut.CleanUp(ctx) + }) + + runOpts := dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, + } + + device := mkDevice(dut) + remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet, containerAddr) + + // Create the Docker container for the testbench. + testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) + + tbb := path.Base(testbenchBinary) + containerTestbenchBinary := filepath.Join("/packetimpact", tbb) + testbench.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb)) + + // snifferNetDev is a network device on the test orchestrator that we will + // run sniffer (tcpdump or tshark) on and inject traffic to, not to be + // confused with the device on the DUT. + const snifferNetDev = "eth2" + // Run tcpdump in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs := []string{ + "tcpdump", + "-S", "-vvv", "-U", "-n", + "-i", snifferNetDev, + "-w", testOutputDir + "/dump.pcap", + } + snifferRegex := "tcpdump: listening.*\n" + if tshark { + // Run tshark in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs = []string{ + "tshark", "-V", "-l", "-n", "-i", snifferNetDev, + "-o", "tcp.check_checksum:TRUE", + "-o", "udp.check_checksum:TRUE", + } + snifferRegex = "Capturing on.*\n" + } + + if err := StartContainer( + ctx, + runOpts, + testbench, + testbenchAddr, + []*dockerutil.Network{ctrlNet, testNet}, + snifferArgs..., + ); err != nil { + t.Fatalf("failed to start docker container for testbench sniffer: %s", err) + } + // Kill so that it will flush output. + t.Cleanup(func() { + time.Sleep(1 * time.Second) + testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) + }) + + if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { + t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) + } + + // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it + // will respond with an RST. In most packetimpact tests, the SYN is sent + // by the raw socket and the kernel knows nothing about the connection, this + // behavior will break lots of TCP related packetimpact tests. To prevent + // this, we can install the following iptables rules. The raw socket that + // packetimpact tests use will still be able to see everything. + for _, bin := range []string{"iptables", "ip6tables"} { + if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", snifferNetDev, "-p", "tcp", "-j", "DROP"); err != nil { + t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) + } + } + + // FIXME(b/156449515): Some piece of the system has a race. The old + // bash script version had a sleep, so we have one too. The race should + // be fixed and this sleep removed. + time.Sleep(time.Second) + + // Start a packetimpact test on the test bench. The packetimpact test sends + // and receives packets and also sends POSIX socket commands to the + // posix_server to be executed on the DUT. + testArgs := []string{containerTestbenchBinary} + testArgs = append(testArgs, extraTestArgs...) + testArgs = append(testArgs, + "--posix_server_ip", AddressInSubnet(DutAddr, *ctrlNet.Subnet).String(), + "--posix_server_port", CtrlPort, + "--remote_ipv4", AddressInSubnet(DutAddr, *testNet.Subnet).String(), + "--local_ipv4", AddressInSubnet(testbenchAddr, *testNet.Subnet).String(), + "--remote_ipv6", remoteIPv6.String(), + "--remote_mac", remoteMAC.String(), + "--remote_interface_id", fmt.Sprintf("%d", dutDeviceID), + "--local_device", snifferNetDev, + "--remote_device", dutTestNetDev, + fmt.Sprintf("--native=%t", native), + ) + testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) + if (err != nil) != expectFailure { + var dutLogs string + if logs, err := device.Logs(ctx); err != nil { + dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err) + } else { + dutLogs = logs + } + + t.Errorf(`test error: %v, expect failure: %t + +%s + +====== Begin of Testbench Logs ====== + +%s + +====== End of Testbench Logs ======`, + err, expectFailure, dutLogs, testbenchLogs) + } +} + +// DUT describes how to setup/teardown the dut for packetimpact tests. +type DUT interface { + // Prepare prepares the dut, starts posix_server and returns the IPv6, MAC + // address, the interface ID, and the interface name for the testNet on DUT. + Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) + // Logs retrieves the logs from the dut. + Logs(ctx context.Context) (string, error) +} + +// DockerDUT describes a docker based DUT. +type DockerDUT struct { + c *dockerutil.Container +} + +// NewDockerDUT creates a docker based DUT. +func NewDockerDUT(c *dockerutil.Container) DUT { + return &DockerDUT{ + c: c, + } +} + +// Prepare implements DUT.Prepare. +func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) { + const containerPosixServerBinary = "/packetimpact/posix_server" + dut.c.CopyFiles(&runOpts, "/packetimpact", "test/packetimpact/dut/posix_server") + + if err := StartContainer( + ctx, + runOpts, + dut.c, + containerAddr, + []*dockerutil.Network{ctrlNet, testNet}, + containerPosixServerBinary, + "--ip=0.0.0.0", + "--port="+CtrlPort, + ); err != nil { + t.Fatalf("failed to start docker container for DUT: %s", err) + } + + if _, err := dut.c.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { + t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err) + } + + dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + + remoteMAC := dutDeviceInfo.MAC + remoteIPv6 := dutDeviceInfo.IPv6Addr + // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if + // needed. + if remoteIPv6 == nil { + if _, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { + t.Fatalf("unable to ip addr add on container %s: %s", dut.c.Name, err) + } + // Now try again, to make sure that it worked. + _, dutDeviceInfo, err = deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + remoteIPv6 = dutDeviceInfo.IPv6Addr + if remoteIPv6 == nil { + t.Fatalf("unable to set IPv6 address on container %s", dut.c.Name) + } + } + const testNetDev = "eth2" + + return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev +} + +// Logs implements DUT.Logs. +func (dut *DockerDUT) Logs(ctx context.Context) (string, error) { + logs, err := dut.c.Logs(ctx) + if err != nil { + return "", err + } + return fmt.Sprintf(`====== Begin of DUT Logs ====== + +%s + +====== End of DUT Logs ======`, logs), nil +} + +// AddNetworks connects docker network with the container and assigns the specific IP. +func AddNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { + for _, dn := range networks { + ip := AddressInSubnet(addr, *dn.Subnet) + // Connect to the network with the specified IP address. + if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { + return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) + } + } + return nil +} + +// AddressInSubnet combines the subnet provided with the address and returns a +// new address. The return address bits come from the subnet where the mask is 1 +// and from the ip address where the mask is 0. +func AddressInSubnet(addr net.IP, subnet net.IPNet) net.IP { + var octets []byte + for i := 0; i < 4; i++ { + octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) + } + return net.IP(octets) +} + +// deviceByIP finds a deviceInfo and device name from an IP address. +func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out) + } + devs, err := netdevs.ParseDevices(out) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out) + } + testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) + } + return testDevice, deviceInfo, nil +} + +// createDockerNetwork makes a randomly-named network that will start with the +// namePrefix. The network will be a random /24 subnet. +func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { + randSource := rand.NewSource(time.Now().UnixNano()) + r1 := rand.New(randSource) + // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) + n.Subnet = &net.IPNet{ + IP: ip, + Mask: ip.DefaultMask(), + } + return n.Create(ctx) +} + +// StartContainer will create a container instance from runOpts, connect it +// with the specified docker networks and start executing the specified cmd. +func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, cmd ...string) error { + conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...) + _ = netconf + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := c.CreateFrom(ctx, conf, hostconf, nil); err != nil { + return fmt.Errorf("unable to create container %s: %w", c.Name, err) + } + + if err := AddNetworks(ctx, c, containerAddr, ns); err != nil { + return fmt.Errorf("unable to connect the container with the networks: %w", err) + } + + if err := c.Start(ctx); err != nil { + return fmt.Errorf("unable to start container %s: %w", c.Name, err) + } + return nil +} diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go index e8c183977..c598bfc29 100644 --- a/test/packetimpact/runner/packetimpact_test.go +++ b/test/packetimpact/runner/packetimpact_test.go @@ -18,366 +18,15 @@ package packetimpact_test import ( "context" "flag" - "fmt" - "io/ioutil" - "log" - "math/rand" - "net" - "os" - "os/exec" - "path" - "strings" "testing" - "time" - "github.com/docker/docker/api/types/mount" - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/packetimpact/netdevs" + "gvisor.dev/gvisor/test/packetimpact/runner" ) -// stringList implements flag.Value. -type stringList []string - -// String implements flag.Value.String. -func (l *stringList) String() string { - return strings.Join(*l, ",") -} - -// Set implements flag.Value.Set. -func (l *stringList) Set(value string) error { - *l = append(*l, value) - return nil -} - -var ( - native = flag.Bool("native", false, "whether the test should be run natively") - testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary") - tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump") - extraTestArgs = stringList{} - expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run") - - dutAddr = net.IPv4(0, 0, 0, 10) - testbenchAddr = net.IPv4(0, 0, 0, 20) -) - -const ctrlPort = "40000" - -// logger implements testutil.Logger. -// -// Labels logs based on their source and formats multi-line logs. -type logger string - -// Name implements testutil.Logger.Name. -func (l logger) Name() string { - return string(l) -} - -// Logf implements testutil.Logger.Logf. -func (l logger) Logf(format string, args ...interface{}) { - lines := strings.Split(fmt.Sprintf(format, args...), "\n") - log.Printf("%s: %s", l, lines[0]) - for _, line := range lines[1:] { - log.Printf("%*s %s", len(l), "", line) - } +func init() { + runner.RegisterFlags(flag.CommandLine) } func TestOne(t *testing.T) { - flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") - flag.Parse() - if *testbenchBinary == "" { - t.Fatal("--testbench_binary is missing") - } - dockerutil.EnsureSupportedDockerVersion() - ctx := context.Background() - - // Create the networks needed for the test. One control network is needed for - // the gRPC control packets and one test network on which to transmit the test - // packets. - ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) - testNet := dockerutil.NewNetwork(ctx, logger("testNet")) - for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { - for { - if err := createDockerNetwork(ctx, dn); err != nil { - t.Log("creating docker network:", err) - const wait = 100 * time.Millisecond - t.Logf("sleeping %s and will try creating docker network again", wait) - // This can fail if another docker network claimed the same IP so we'll - // just try again. - time.Sleep(wait) - continue - } - break - } - defer func(dn *dockerutil.Network) { - if err := dn.Cleanup(ctx); err != nil { - t.Errorf("unable to cleanup container %s: %s", dn.Name, err) - } - }(dn) - // Sanity check. - inspect, err := dn.Inspect(ctx) - if err != nil { - t.Fatalf("failed to inspect network %s: %v", dn.Name, err) - } else if inspect.Name != dn.Name { - t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) - } - - } - - tmpDir, err := ioutil.TempDir("", "container-output") - if err != nil { - t.Fatal("creating temp dir:", err) - } - defer os.RemoveAll(tmpDir) - - const testOutputDir = "/tmp/testoutput" - - // Create the Docker container for the DUT. - var dut *dockerutil.Container - if *native { - dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) - } else { - dut = dockerutil.MakeContainer(ctx, logger("dut")) - } - - runOpts := dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Mounts: []mount.Mount{mount.Mount{ - Type: mount.TypeBind, - Source: tmpDir, - Target: testOutputDir, - ReadOnly: false, - }}, - } - - const containerPosixServerBinary = "/packetimpact/posix_server" - dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server") - - conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort) - hostconf.AutoRemove = true - hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} - - if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil { - t.Fatalf("unable to create container %s: %v", dut.Name, err) - } - - defer dut.CleanUp(ctx) - - // Add ctrlNet as eth1 and testNet as eth2. - const testNetDev = "eth2" - if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { - t.Fatal(err) - } - - if err := dut.Start(ctx); err != nil { - t.Fatalf("unable to start container %s: %s", dut.Name, err) - } - - if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { - t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err) - } - - dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) - if err != nil { - t.Fatal(err) - } - - remoteMAC := dutDeviceInfo.MAC - remoteIPv6 := dutDeviceInfo.IPv6Addr - // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if - // needed. - if remoteIPv6 == nil { - if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { - t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err) - } - // Now try again, to make sure that it worked. - _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) - if err != nil { - t.Fatal(err) - } - remoteIPv6 = dutDeviceInfo.IPv6Addr - if remoteIPv6 == nil { - t.Fatal("unable to set IPv6 address on container", dut.Name) - } - } - - // Create the Docker container for the testbench. - testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) - - tbb := path.Base(*testbenchBinary) - containerTestbenchBinary := "/packetimpact/" + tbb - runOpts = dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Mounts: []mount.Mount{mount.Mount{ - Type: mount.TypeBind, - Source: tmpDir, - Target: testOutputDir, - ReadOnly: false, - }}, - } - testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb) - - // Run tcpdump in the test bench unbuffered, without DNS resolution, just on - // the interface with the test packets. - snifferArgs := []string{ - "tcpdump", - "-S", "-vvv", "-U", "-n", - "-i", testNetDev, - "-w", testOutputDir + "/dump.pcap", - } - snifferRegex := "tcpdump: listening.*\n" - if *tshark { - // Run tshark in the test bench unbuffered, without DNS resolution, just on - // the interface with the test packets. - snifferArgs = []string{ - "tshark", "-V", "-l", "-n", "-i", testNetDev, - "-o", "tcp.check_checksum:TRUE", - "-o", "udp.check_checksum:TRUE", - } - snifferRegex = "Capturing on.*\n" - } - - defer func() { - if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { - t.Error("unable to copy container output files:", err) - } - }() - - conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...) - hostconf.AutoRemove = true - hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} - - if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil { - t.Fatalf("unable to create container %s: %s", testbench.Name, err) - } - defer testbench.CleanUp(ctx) - - // Add ctrlNet as eth1 and testNet as eth2. - if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { - t.Fatal(err) - } - - if err := testbench.Start(ctx); err != nil { - t.Fatalf("unable to start container %s: %s", testbench.Name, err) - } - - // Kill so that it will flush output. - defer func() { - time.Sleep(1 * time.Second) - testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) - }() - - if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { - t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) - } - - // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it - // will issue an RST. To prevent this IPtables can be used to filter out all - // incoming packets. The raw socket that packetimpact tests use will still see - // everything. - for _, bin := range []string{"iptables", "ip6tables"} { - if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil { - t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) - } - } - - // FIXME(b/156449515): Some piece of the system has a race. The old - // bash script version had a sleep, so we have one too. The race should - // be fixed and this sleep removed. - time.Sleep(time.Second) - - // Start a packetimpact test on the test bench. The packetimpact test sends - // and receives packets and also sends POSIX socket commands to the - // posix_server to be executed on the DUT. - testArgs := []string{containerTestbenchBinary} - testArgs = append(testArgs, extraTestArgs...) - testArgs = append(testArgs, - "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(), - "--posix_server_port", ctrlPort, - "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(), - "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(), - "--remote_ipv6", remoteIPv6.String(), - "--remote_mac", remoteMAC.String(), - "--remote_interface_id", fmt.Sprintf("%d", dutDeviceInfo.ID), - "--device", testNetDev, - fmt.Sprintf("--native=%t", *native), - ) - testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) - if (err != nil) != *expectFailure { - var dutLogs string - if logs, err := dut.Logs(ctx); err != nil { - dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err) - } else { - dutLogs = logs - } - - t.Errorf(`test error: %v, expect failure: %t - -====== Begin of DUT Logs ====== - -%s - -====== End of DUT Logs ====== - -====== Begin of Testbench Logs ====== - -%s - -====== End of Testbench Logs ======`, - err, *expectFailure, dutLogs, testbenchLogs) - } -} - -func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { - for _, dn := range networks { - ip := addressInSubnet(addr, *dn.Subnet) - // Connect to the network with the specified IP address. - if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { - return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) - } - } - return nil -} - -// addressInSubnet combines the subnet provided with the address and returns a -// new address. The return address bits come from the subnet where the mask is 1 -// and from the ip address where the mask is 0. -func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP { - var octets []byte - for i := 0; i < 4; i++ { - octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) - } - return net.IP(octets) -} - -// createDockerNetwork makes a randomly-named network that will start with the -// namePrefix. The network will be a random /24 subnet. -func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { - randSource := rand.NewSource(time.Now().UnixNano()) - r1 := rand.New(randSource) - // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. - ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) - n.Subnet = &net.IPNet{ - IP: ip, - Mask: ip.DefaultMask(), - } - return n.Create(ctx) -} - -// deviceByIP finds a deviceInfo and device name from an IP address. -func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { - out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err) - } - devs, err := netdevs.ParseDevices(out) - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err) - } - testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) - } - return testDevice, deviceInfo, nil + runner.TestWithDUT(context.Background(), t, runner.NewDockerDUT, runner.DutAddr) } diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 3af5f83fd..a90046f69 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -615,7 +615,7 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du if errs == nil { return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout) } - return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs) + return nil, fmt.Errorf("got frames %w want %v during %s", errs, layers, timeout) } if conn.match(layers, gotLayers) { for i, s := range conn.layerStates { diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 73c532e75..6165ab293 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -16,11 +16,13 @@ package testbench import ( "context" + "encoding/binary" "flag" "net" "strconv" "syscall" "testing" + "time" pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" @@ -700,3 +702,43 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, fl } return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) } + +// SetSockLingerOption sets SO_LINGER socket option on the DUT. +func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Duration, enable bool) { + var linger unix.Linger + if enable { + linger.Onoff = 1 + } + linger.Linger = int32(timeout / time.Second) + + buf := make([]byte, 8) + binary.LittleEndian.PutUint32(buf, uint32(linger.Onoff)) + binary.LittleEndian.PutUint32(buf[4:], uint32(linger.Linger)) + dut.SetSockOpt(t, sockfd, unix.SOL_SOCKET, unix.SO_LINGER, buf) +} + +// Shutdown calls shutdown on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// ShutdownWithErrno. +func (dut *DUT) Shutdown(t *testing.T, fd, how int32) error { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + return dut.ShutdownWithErrno(ctx, t, fd, how) +} + +// ShutdownWithErrno calls shutdown on the DUT. +func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) error { + t.Helper() + + req := pb.ShutdownRequest{ + Fd: fd, + How: how, + } + resp, err := dut.posixServer.Shutdown(ctx, &req) + if err != nil { + t.Fatalf("failed to call Shutdown: %s", err) + } + return syscall.Errno(resp.GetErrno_()) +} diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 57e822725..193bb2dc8 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -139,7 +139,7 @@ type Injector struct { func NewInjector(t *testing.T) (Injector, error) { t.Helper() - ifInfo, err := net.InterfaceByName(Device) + ifInfo, err := net.InterfaceByName(LocalDevice) if err != nil { return Injector{}, err } diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index e3629e1f3..0073a1361 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -29,8 +29,11 @@ import ( var ( // Native indicates that the test is being run natively. Native = false - // Device is the local device on the test network. - Device = "" + // LocalDevice is the device that testbench uses to inject traffic. + LocalDevice = "" + // RemoteDevice is the device name on the DUT, individual tests can + // use the name to construct tests. + RemoteDevice = "" // LocalIPv4 is the local IPv4 address on the test network. LocalIPv4 = "" @@ -80,7 +83,8 @@ func RegisterFlags(fs *flag.FlagSet) { fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets") fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets") fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") - fs.StringVar(&Device, "device", Device, "local device for test packets") + fs.StringVar(&LocalDevice, "local_device", LocalDevice, "local device to inject traffic") + fs.StringVar(&RemoteDevice, "remote_device", RemoteDevice, "remote device on the DUT") fs.BoolVar(&Native, "native", Native, "whether the test is running natively") fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets") } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 74658fea0..fbfea61e1 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -166,8 +166,8 @@ packetimpact_go_test( ) packetimpact_go_test( - name = "tcp_close_wait_ack", - srcs = ["tcp_close_wait_ack_test.go"], + name = "tcp_unacc_seq_ack", + srcs = ["tcp_unacc_seq_ack_test.go"], deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", @@ -260,6 +260,28 @@ packetimpact_go_test( ) packetimpact_go_test( + name = "tcp_timewait_reset", + srcs = ["tcp_timewait_reset_test.go"], + # TODO(b/168523247): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_queue_send_in_syn_sent", + srcs = ["tcp_queue_send_in_syn_sent_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( name = "icmpv6_param_problem", srcs = ["icmpv6_param_problem_test.go"], # TODO(b/153485026): Fix netstack then remove the line below. @@ -308,3 +330,13 @@ packetimpact_go_test( "@org_golang_x_sys//unix:go_default_library", ], ) + +packetimpact_go_test( + name = "tcp_linger", + srcs = ["tcp_linger_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go deleted file mode 100644 index e6a96f214..000000000 --- a/test/packetimpact/tests/tcp_close_wait_ack_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_close_wait_ack_test - -import ( - "flag" - "fmt" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.RegisterFlags(flag.CommandLine) -} - -func TestCloseWaitAck(t *testing.T) { - for _, tt := range []struct { - description string - makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP - seqNumOffset seqnum.Size - expectAck bool - }{ - {"OTW", generateOTWSeqSegment, 0, false}, - {"OTW", generateOTWSeqSegment, 1, true}, - {"OTW", generateOTWSeqSegment, 2, true}, - {"ACK", generateUnaccACKSegment, 0, false}, - {"ACK", generateUnaccACKSegment, 1, true}, - {"ACK", generateUnaccACKSegment, 2, true}, - } { - t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(t, listenFd) - conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close(t) - - conn.Connect(t) - acceptFd, _ := dut.Accept(t, listenFd) - - // Send a FIN to DUT to intiate the active close - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) - if err != nil { - t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) - } - windowSize := seqnum.Size(*gotTCP.WindowSize) - - // Send a segment with OTW Seq / unacc ACK and expect an ACK back - conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) - gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) - if tt.expectAck && err != nil { - t.Fatalf("expected an ack but got none: %s", err) - } - if !tt.expectAck && gotAck != nil { - t.Fatalf("expected no ack but got one: %s", gotAck) - } - - // Now let's verify DUT is indeed in CLOSE_WAIT - dut.Close(t, acceptFd) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { - t.Fatalf("expected DUT to send a FIN: %s", err) - } - // Ack the FIN from DUT - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - // Send some extra data to DUT - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { - t.Fatalf("expected DUT to send an RST: %s", err) - } - }) - } -} - -// generateOTWSeqSegment generates an segment with -// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only -// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the -// receiver. -func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) - otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} -} - -// generateUnaccACKSegment generates an segment with -// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable -// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. -func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.RemoteSeqNum(t) - unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} -} diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go new file mode 100644 index 000000000..913e49e06 --- /dev/null +++ b/test/packetimpact/tests/tcp_linger_test.go @@ -0,0 +1,253 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_linger_test + +import ( + "context" + "flag" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func createSocket(t *testing.T, dut testbench.DUT) (int32, int32, testbench.TCPIPv4) { + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + return acceptFD, listenFD, conn +} + +func closeAll(t *testing.T, dut testbench.DUT, listenFD int32, conn testbench.TCPIPv4) { + conn.Close(t) + dut.Close(t, listenFD) + dut.TearDown() +} + +// lingerDuration is the timeout value used with SO_LINGER socket option. +const lingerDuration = 3 * time.Second + +// TestTCPLingerZeroTimeout tests when SO_LINGER is set with zero timeout. DUT +// should send RST-ACK when socket is closed. +func TestTCPLingerZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Close(t, acceptFD) + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerOff tests when SO_LINGER is not set. DUT should send FIN-ACK +// when socket is closed. +func TestTCPLingerOff(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.Close(t, acceptFD) + + // If SO_LINGER is not set, DUT should send a FIN-ACK. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerNonZeroTimeout tests when SO_LINGER is set with non-zero timeout. +// DUT should close the socket after timeout. +func TestTCPLingerNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerSendNonZeroTimeout tests when SO_LINGER is set with non-zero +// timeout and send a packet. DUT should close the socket after timeout. +func TestTCPLingerSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithSendNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerShutdownZeroTimeout tests SO_LINGER with shutdown() and zero +// timeout. DUT should send RST-ACK when socket is closed. +func TestTCPLingerShutdownZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + dut.Close(t, acceptFD) + + // Shutdown will send FIN-ACK with read/write option. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerShutdownSendNonZeroTimeout tests SO_LINGER with shutdown() and +// non-zero timeout. DUT should close the socket after timeout. +func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"shutdownRDWR", true}, + {"shutdownRDWR", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} diff --git a/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go new file mode 100644 index 000000000..0ec8fd748 --- /dev/null +++ b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_queue_send_in_syn_sent_test + +import ( + "context" + "errors" + "flag" + "net" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestQueueSendInSynSent tests send behavior when the TCP state +// is SYN-SENT. +// It tests for 2 variants when in SYN_SENT state and: +// (1) DUT blocks on send and complete handshake +// (2) DUT blocks on send and receive a TCP RST. +func TestQueueSendInSynSent(t *testing.T) { + for _, tt := range []struct { + description string + reset bool + }{ + {description: "Complete handshake", reset: false}, + {description: "Send RST", reset: true}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + dut.SetNonBlocking(t, socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + t.Fatalf("expected a SYN from DUT, but got none: %s", err) + } + if _, err := dut.SendWithErrno(context.Background(), t, socket, sampleData, 0); err != syscall.Errno(unix.EWOULDBLOCK) { + t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + } + + // Test blocking write. + dut.SetNonBlocking(t, socket, false) + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + block.Done() + // Issue SEND call in SYN-SENT, this should be queued for + // process until the connection is established. + n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0) + if tt.reset { + if err != syscall.Errno(unix.ECONNREFUSED) { + t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + } + if n != -1 { + t.Errorf("expected return value %d, got %d", -1, n) + } + return + } + if n != int32(len(sampleData)) { + t.Errorf("failed to send on DUT: %s", err) + } + }() + + // Wait for the goroutine to be scheduled and before it + // blocks on endpoint send. + block.Wait() + // The following sleep is used to prevent the connection + // from being established before we are blocked on send. + time.Sleep(100 * time.Millisecond) + + if tt.reset { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + return + } + + // Bring the connection to Established. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + + // Expect the data from the DUT's enqueued send request. + // + // On Linux, this can be piggybacked with the ACK completing the + // handshake. On gVisor, getting such a piggyback is a bit more + // complicated because the actual data enqueuing occurs in the + // callers of endpoint Write. + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Send sample payload and expect an ACK to ensure connection is still ESTABLISHED. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_timewait_reset_test.go b/test/packetimpact/tests/tcp_timewait_reset_test.go new file mode 100644 index 000000000..2f76a6531 --- /dev/null +++ b/test/packetimpact/tests/tcp_timewait_reset_test.go @@ -0,0 +1,68 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_timewait_reset_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTimeWaitReset tests handling of RST when in TIME_WAIT state. +func TestTimeWaitReset(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Close(t, acceptFD) + + _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected a FIN: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send a FIN, DUT should transition to TIME_WAIT from FIN_WAIT2. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our FIN: %s", err) + } + + // Send a RST, the DUT should transition to CLOSED from TIME_WAIT. + // This is the default Linux behavior, it can be changed to ignore RSTs via + // sysctl net.ipv4.tcp_rfc1337. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // The DUT should reply with RST to our ACK as the state should have + // transitioned to CLOSED. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go new file mode 100644 index 000000000..d078bbf15 --- /dev/null +++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go @@ -0,0 +1,234 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_unacc_seq_ack_test + +import ( + "flag" + "fmt" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestEstablishedUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + restoreSeq bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + dut.Accept(t, listenFD) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected ack %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + origSeq := *conn.LocalSeqNum(t) + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload) + if tt.restoreSeq { + // Restore the local sequence number to ensure that the incoming + // ACK matches the TCP layer state. + *conn.LocalSeqNum(t) = origSeq + } + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if err == nil && !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + }) + } +} + +func TestPassiveCloseUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: false}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Send a FIN to DUT to intiate the passive close. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Errorf("expected an ack but got none: %s", err) + } + if err == nil && !tt.expectAck && gotAck != nil { + t.Errorf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(t, acceptFD) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, samplePayload) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +func TestActiveCloseUnaccpSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + restoreSeq bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Shutdown(t, acceptFD, syscall.SHUT_WR) + + // Get to FIN_WAIT2 + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected a FIN: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + sendUnaccSeqAck := func(state string) { + t.Helper() + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + origSeq := *conn.LocalSeqNum(t) + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize)), samplePayload) + if tt.restoreSeq { + // Restore the local sequence number to ensure that the + // incoming ACK matches the TCP layer state. + *conn.LocalSeqNum(t) = origSeq + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected an ack in %s state, but got none: %s", state, err) + } + } + + sendUnaccSeqAck("FIN_WAIT2") + + // Send a FIN to DUT to get to TIME_WAIT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter TIME_WAIT: %s", err) + } + + sendUnaccSeqAck("TIME_WAIT") + }) + } +} + +// generateOTWSeqSegment generates an segment with +// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only +// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the +// receiver. +func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} +} + +// generateUnaccACKSegment generates an segment with +// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable +// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. +func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.RemoteSeqNum(t) + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go index d30177e64..3d2791a6e 100644 --- a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go +++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go @@ -53,6 +53,7 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { t, testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))}, testbench.UDP{}, + &testbench.Payload{Bytes: []byte("test payload")}, ) ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) @@ -76,14 +77,15 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { net.IPv6interfacelocalallnodes, net.IPv6linklocalallnodes, net.IPv6linklocalallrouters, - net.ParseIP("fe01::42"), - net.ParseIP("fe02::4242"), + net.ParseIP("ff01::42"), + net.ParseIP("ff02::4242"), } { t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { conn.SendIPv6( t, testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))}, testbench.UDP{}, + &testbench.Payload{Bytes: []byte("test payload")}, ) ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { diff --git a/test/perf/BUILD b/test/perf/BUILD index 471d8c2ab..b763be50e 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -3,33 +3,40 @@ load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) syscall_test( + debug = False, test = "//test/perf/linux:clock_getres_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:clock_gettime_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:death_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:epoll_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:fork_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:futex_benchmark", ) syscall_test( size = "enormous", + debug = False, shard_count = 10, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", @@ -37,81 +44,96 @@ syscall_test( syscall_test( size = "large", + debug = False, test = "//test/perf/linux:getpid_benchmark", ) syscall_test( size = "enormous", + debug = False, tags = ["nogotsan"], test = "//test/perf/linux:gettid_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:mapping_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:open_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:pipe_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:randread_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:read_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:sched_yield_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:send_recv_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:seqwrite_benchmark", ) syscall_test( size = "enormous", + debug = False, test = "//test/perf/linux:signal_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:sleep_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:stat_benchmark", ) syscall_test( size = "enormous", add_overlay = True, + debug = False, test = "//test/perf/linux:unlink_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:write_benchmark", ) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD index b4e907826..dd1d2438c 100644 --- a/test/perf/linux/BUILD +++ b/test/perf/linux/BUILD @@ -354,3 +354,19 @@ cc_binary( "//test/util:test_util", ], ) + +cc_binary( + name = "open_read_close_benchmark", + testonly = 1, + srcs = [ + "open_read_close_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc index d8e81fa8c..9030eb356 100644 --- a/test/perf/linux/getdents_benchmark.cc +++ b/test/perf/linux/getdents_benchmark.cc @@ -105,7 +105,7 @@ void BM_GetdentsSameFD(benchmark::State& state) { state.SetItemsProcessed(state.iterations()); } -BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); +BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 12)->UseRealTime(); // Creates a directory containing `files` files, and reads all the directory // entries from the directory using a new FD each time. diff --git a/test/perf/linux/open_read_close_benchmark.cc b/test/perf/linux/open_read_close_benchmark.cc new file mode 100644 index 000000000..8b023a3d8 --- /dev/null +++ b/test/perf/linux/open_read_close_benchmark.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_OpenReadClose(benchmark::State& state) { + const int size = state.range(0); + std::vector<TempPath> cache; + for (int i = 0; i < size; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), "some content", 0644)); + cache.emplace_back(std::move(path)); + } + + char buf[1]; + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(cache[chosen].path().c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + TEST_CHECK(read(fd, buf, 1) == 1); + close(fd); + } +} + +// Gofer dentry cache is 1000 by default. Go over it to force files to be closed +// for real. +BENCHMARK(BM_OpenReadClose)->Range(1000, 16384)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go index df91fa0fe..11ac5cb52 100644 --- a/test/root/crictl_test.go +++ b/test/root/crictl_test.go @@ -418,7 +418,7 @@ func setup(t *testing.T, version string) (*criutil.Crictl, func(), error) { // care about the docker runtime name. config = v2Template default: - t.Fatalf("unknown version: %d", version) + t.Fatalf("unknown version: %s", version) } t.Logf("Using config: %s", config) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 2d64934b0..032ebd04e 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -62,7 +62,8 @@ def _syscall_test( overlay = False, add_uds_tree = False, vfs2 = False, - fuse = False): + fuse = False, + debug = True): # Prepend "runsc" to non-native platform names. full_platform = platform if platform == "native" else "runsc_" + platform @@ -111,6 +112,8 @@ def _syscall_test( "--add-uds-tree=" + str(add_uds_tree), "--vfs2=" + str(vfs2), "--fuse=" + str(fuse), + "--strace=" + str(debug), + "--debug=" + str(debug), ] # Call the rule above. @@ -134,6 +137,7 @@ def syscall_test( add_hostinet = False, vfs2 = True, fuse = False, + debug = True, tags = None): """syscall_test is a macro that will create targets for all platforms. diff --git a/test/runner/runner.go b/test/runner/runner.go index 5ac91310d..22d535f8d 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -106,11 +106,14 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + cmd.SysProcAttr = &syscall.SysProcAttr{} + + if specutils.HasCapabilities(capability.CAP_SYS_ADMIN) { + cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWUTS + } if specutils.HasCapabilities(capability.CAP_NET_ADMIN) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - Cloneflags: syscall.CLONE_NEWNET, - } + cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWNET } if err := cmd.Run(); err != nil { diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 066338ee3..22b526f59 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -5,7 +5,7 @@ package(licenses = ["notice"]) runtime_test( name = "go1.12", - exclude_file = "exclude_go1.12.csv", + exclude_file = "exclude/go1.12.csv", lang = "go", shard_count = 8, ) @@ -13,28 +13,28 @@ runtime_test( runtime_test( name = "java11", batch = 100, - exclude_file = "exclude_java11.csv", + exclude_file = "exclude/java11.csv", lang = "java", shard_count = 16, ) runtime_test( name = "nodejs12.4.0", - exclude_file = "exclude_nodejs12.4.0.csv", + exclude_file = "exclude/nodejs12.4.0.csv", lang = "nodejs", shard_count = 8, ) runtime_test( name = "php7.3.6", - exclude_file = "exclude_php7.3.6.csv", + exclude_file = "exclude/php7.3.6.csv", lang = "php", shard_count = 8, ) runtime_test( name = "python3.7.3", - exclude_file = "exclude_python3.7.3.csv", + exclude_file = "exclude/python3.7.3.csv", lang = "python", shard_count = 8, ) diff --git a/test/runtimes/README.md b/test/runtimes/README.md new file mode 100644 index 000000000..9dda1a728 --- /dev/null +++ b/test/runtimes/README.md @@ -0,0 +1,62 @@ +# gVisor Runtime Tests + +App Engine uses gvisor to sandbox application containers. The runtime tests aim +to test `runsc` compatibility with these +[standard runtimes](https://cloud.google.com/appengine/docs/standard/runtimes). +The test itself runs the language-defined tests inside the sandboxed standard +runtime container. + +Note: [Ruby runtime](https://cloud.google.com/appengine/docs/standard/ruby) is +currently in beta mode and so we do not run tests for it yet. + +### Testing Locally + +To run runtime tests individually from a given runtime, use the following table. + +Language | Version | Download Image | Run Test(s) +-------- | ------- | ------------------------------------------- | ----------- +Go | 1.12 | `make -C images load-runtimes_go1.12` | If the test name ends with `.go`, it is an on-disk test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 ( cd /usr/local/go/test ; go run run.go -v -- <TEST_NAME>... )` <br> Otherwise it is a tool test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 go tool dist test -v -no-rebuild ^TEST1$\|^TEST2$...` +Java | 11 | `make -C images load-runtimes_java11` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/java11 jtreg -agentvm -dir:/root/test/jdk -noreport -timeoutFactor:20 -verbose:summary <TEST_NAME>...` +NodeJS | 12.4.0 | `make -C images load-runtimes_nodejs12.4.0` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/nodejs12.4.0 python tools/test.py --timeout=180 <TEST_NAME>...` +Php | 7.3.6 | `make -C images load-runtimes_php7.3.6` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/php7.3.6 make test "TESTS=<TEST_NAME>..."` +Python | 3.7.3 | `make -C images load-runtimes_python3.7.3` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/python3.7.3 ./python -m test <TEST_NAME>...` + +To run an entire runtime test locally, use the following table. + +Note: java runtime test take 1+ hours with 16 cores. + +Language | Version | Running the test suite +-------- | ------- | ---------------------------------------- +Go | 1.12 | `make go1.12-runtime-tests{_vfs2}` +Java | 11 | `make java11-runtime-tests{_vfs2}` +NodeJS | 12.4.0 | `make nodejs12.4.0-runtime-tests{_vfs2}` +Php | 7.3.6 | `make php7.3.6-runtime-tests{_vfs2}` +Python | 3.7.3 | `make python3.7.3-runtime-tests{_vfs2}` + +#### Clean Up + +Sometimes when runtime tests fail or when the testing container itself crashes +unexpectedly, the containers are not removed or sometimes do not even exit. This +can cause some docker commands like `docker system prune` to hang forever. + +Here are some helpful commands (should be executed in order): + +```bash +docker ps -a # Lists all docker processes; useful when investigating hanging containers. +docker kill $(docker ps -a -q) # Kills all running containers. +docker rm $(docker ps -a -q) # Removes all exited containers. +docker system prune # Remove unused data. +``` + +### Testing Infrastructure + +There are 3 components to this tests infrastructure: + +- [`runner`](runner) - This is the test entrypoint. This is the binary is + invoked by `bazel test`. The runner spawns the target runtime container + using `runsc` and then copies over the `proctor` binary into the container. +- [`proctor`](proctor) - This binary acts as our agent inside the container + which communicates with the runner and actually executes tests. +- [`exclude`](exclude) - Holds a CSV file for each language runtime containing + the full path of tests that should be excluded from running along with a + reason for exclusion. diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude/go1.12.csv index 81e02cf64..81e02cf64 100644 --- a/test/runtimes/exclude_go1.12.csv +++ b/test/runtimes/exclude/go1.12.csv diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude/java11.csv index 997a29cad..997a29cad 100644 --- a/test/runtimes/exclude_java11.csv +++ b/test/runtimes/exclude/java11.csv diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv index 1d8e65fd0..749fb9482 100644 --- a/test/runtimes/exclude_nodejs12.4.0.csv +++ b/test/runtimes/exclude/nodejs12.4.0.csv @@ -1,4 +1,5 @@ test name,bug id,comment +async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29 benchmark/test-benchmark-fs.js,, benchmark/test-benchmark-napi.js,, doctool/test-make-doc.js,b/68848110,Expected to fail. @@ -11,8 +12,9 @@ parallel/test-dgram-socket-buffer-size.js,b/68847921, parallel/test-dns-channel-timeout.js,b/161893056, parallel/test-fs-access.js,, parallel/test-fs-watchfile.js,,Flaky - File already exists error -parallel/test-fs-write-stream.js,,Flaky -parallel/test-fs-write-stream-throw-type-error.js,b/110226209, +parallel/test-fs-write-stream.js,b/166819807,Flaky +parallel/test-fs-write-stream-double-close,b/166819807,Flaky +parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 parallel/test-os.js,b/63997097, parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE @@ -49,6 +51,7 @@ pseudo-tty/test-tty-wrap.js,b/162801321, pummel/test-heapdump-http2.js,,Flaky pummel/test-net-pingpong.js,, pummel/test-vm-memleak.js,b/162799436, +pummel/test-watch-file.js,,Flaky - Timeout sequential/test-child-process-pass-fd.js,b/63926391,Flaky sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout diff --git a/test/runtimes/exclude_php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv index 2ce979dc8..a73f3bcfb 100644 --- a/test/runtimes/exclude_php7.3.6.csv +++ b/test/runtimes/exclude/php7.3.6.csv @@ -13,6 +13,13 @@ ext/session/tests/session_set_save_handler_class_018.phpt,, ext/session/tests/session_set_save_handler_iface_003.phpt,, ext/session/tests/session_set_save_handler_sid_001.phpt,, ext/session/tests/session_set_save_handler_variation4.phpt,, +ext/standard/tests/file/disk.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_basic.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_error.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_variation.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_basic.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_error.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_variation.phpt,https://bugs.php.net/bug.php?id=80018, ext/standard/tests/file/fopen_variation19.phpt,b/162894964, ext/standard/tests/file/lstat_stat_variation14.phpt,,Flaky ext/standard/tests/file/php_fd_wrapper_01.phpt,, @@ -25,6 +32,7 @@ ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/16289534 ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223, ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, +ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky ext/standard/tests/streams/stream_socket_sendto.phpt,, ext/standard/tests/strings/007.phpt,, diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv index 8760f8951..8760f8951 100644 --- a/test/runtimes/exclude_python3.7.3.csv +++ b/test/runtimes/exclude/python3.7.3.csv diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD index f76e2ddc0..d1935cbe8 100644 --- a/test/runtimes/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -21,6 +21,7 @@ go_test( size = "small", srcs = ["proctor_test.go"], library = ":proctor", + nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems. pure = True, deps = [ "//pkg/test/testutil", diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 0eadc6b08..f949bc0e3 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -64,6 +64,8 @@ syscall_test( syscall_test( size = "large", + # Produce too many logs in the debug mode. + debug = False, shard_count = 50, # Takes too long for TSAN. Since this is kind of a stress test that doesn't # involve much concurrency, TSAN's usefulness here is limited anyway. @@ -251,12 +253,20 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:ip6tables_test", +) + +syscall_test( size = "large", shard_count = 5, test = "//test/syscalls/linux:itimer_test", ) syscall_test( + test = "//test/syscalls/linux:kcov_test", +) + +syscall_test( test = "//test/syscalls/linux:kill_test", ) @@ -662,6 +672,21 @@ syscall_test( ) syscall_test( + size = "medium", + # Takes too long under gotsan to run. + tags = ["nogotsan"], + test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_nogotsan_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_netlink_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_ipv6_udp_unbound_loopback_netlink_test", +) + +syscall_test( test = "//test/syscalls/linux:socket_ip_unbound_test", ) @@ -778,6 +803,7 @@ syscall_test( syscall_test( add_overlay = True, test = "//test/syscalls/linux:statfs_test", + use_tmpfs = True, # Test specifically relies on TEST_TMPDIR to be tmpfs. ) syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 66a31cd28..451feb8f5 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -22,6 +22,7 @@ exports_files( "socket_ipv4_tcp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_loopback.cc", + "socket_ipv4_udp_unbound_loopback_nogotsan.cc", "tcp_socket.cc", "udp_bind.cc", "udp_socket.cc", @@ -1030,6 +1031,24 @@ cc_binary( ) cc_binary( + name = "ip6tables_test", + testonly = 1, + srcs = [ + "ip6tables.cc", + ], + linkstatic = 1, + deps = [ + ":iptables_types", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "itimer_test", testonly = 1, srcs = ["itimer.cc"], @@ -1050,6 +1069,20 @@ cc_binary( ) cc_binary( + name = "kcov_test", + testonly = 1, + srcs = ["kcov.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "kill_test", testonly = 1, srcs = ["kill.cc"], @@ -1640,6 +1673,7 @@ cc_binary( gtest, "//test/util:memory_util", "//test/util:posix_error", + "//test/util:proc_util", "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", @@ -1862,6 +1896,7 @@ cc_binary( srcs = ["readahead.cc"], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:file_descriptor", gtest, "//test/util:temp_path", @@ -1953,6 +1988,7 @@ cc_binary( gtest, "//test/util:logging", "//test/util:multiprocess_util", + "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", ], @@ -2383,6 +2419,43 @@ cc_library( ) cc_library( + name = "socket_ipv4_udp_unbound_netlink_test_cases", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_netlink.cc", + ], + hdrs = [ + "socket_ipv4_udp_unbound_netlink.h", + ], + deps = [ + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:cleanup", + gtest, + ], + alwayslink = 1, +) + +cc_library( + name = "socket_ipv6_udp_unbound_netlink_test_cases", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound_netlink.cc", + ], + hdrs = [ + "socket_ipv6_udp_unbound_netlink.h", + ], + deps = [ + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + gtest, + ], + alwayslink = 1, +) + +cc_library( name = "socket_ipv4_udp_unbound_external_networking_test_cases", testonly = 1, srcs = [ @@ -2719,6 +2792,55 @@ cc_binary( ) cc_binary( + name = "socket_ipv4_udp_unbound_loopback_nogotsan_test", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_loopback_nogotsan.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + gtest, + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/memory", + ], +) + +cc_binary( + name = "socket_ipv4_udp_unbound_loopback_netlink_test", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_loopback_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv4_udp_unbound_netlink_test_cases", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "socket_ipv6_udp_unbound_loopback_netlink_test", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound_loopback_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv6_udp_unbound_netlink_test_cases", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "socket_ip_unbound_test", testonly = 1, srcs = [ @@ -3547,15 +3669,12 @@ cc_binary( ], ) -cc_library( - name = "udp_socket_test_cases", +cc_binary( + name = "udp_socket_test", testonly = 1, - srcs = [ - "udp_socket_errqueue_test_case.cc", - "udp_socket_test_cases.cc", - ], - hdrs = ["udp_socket_test_cases.h"], + srcs = ["udp_socket.cc"], defines = select_system(), + linkstatic = 1, deps = [ ":ip_socket_test_util", ":socket_test_util", @@ -3570,17 +3689,6 @@ cc_library( "//test/util:test_util", "//test/util:thread_util", ], - alwayslink = 1, -) - -cc_binary( - name = "udp_socket_test", - testonly = 1, - srcs = ["udp_socket.cc"], - linkstatic = 1, - deps = [ - ":udp_socket_test_cases", - ], ) cc_binary( diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 18d2f22c1..3797fd4c8 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -1042,6 +1042,13 @@ class ElfInterpreterStaticTest // Statically linked ELF with a statically linked ELF interpreter. TEST_P(ElfInterpreterStaticTest, Test) { + // TODO(gvisor.dev/issue/3721): Test has been observed to segfault on 5.X + // kernels. + if (!IsRunningOnGvisor()) { + auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); + SKIP_IF(version.major > 4); + } + const std::vector<char> segment_suffix = std::get<0>(GetParam()); const int expected_errno = std::get<1>(GetParam()); diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index cabc2b751..edd23e063 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -179,6 +179,12 @@ TEST_F(AllocateTest, FallocateOtherFDs) { auto sock0 = FileDescriptor(socks[0]); auto sock1 = FileDescriptor(socks[1]); EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + int pipefds[2]; + ASSERT_THAT(pipe(pipefds), SyscallSucceeds()); + EXPECT_THAT(fallocate(pipefds[1], 0, 0, 10), SyscallFailsWithErrno(ESPIPE)); + close(pipefds[0]); + close(pipefds[1]); } } // namespace diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index 638a93979..549141cbb 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -185,7 +185,7 @@ TEST_F(FlockTest, TestMultipleHolderSharedExclusive) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } -TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) { +TEST_F(FlockTest, TestSharedLockFailExclusiveHolderNonblocking) { // This test will verify that a shared lock is denied while // someone holds an exclusive lock. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), @@ -203,7 +203,33 @@ TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } -TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) { +void trivial_handler(int signum) {} + +TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking_NoRandomSave) { + const DisableSave ds; // Timing-related. + + // This test will verify that a shared lock is denied while + // someone holds an exclusive lock. + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), + SyscallSucceedsWithValue(0)); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); + + // Register a signal handler for SIGALRM and set an alarm that will go off + // while blocking in the subsequent flock() call. This will interrupt flock() + // and cause it to return EINTR. + struct sigaction act = {}; + act.sa_handler = trivial_handler; + ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); + ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + ASSERT_THAT(flock(fd.get(), LOCK_SH), SyscallFailsWithErrno(EINTR)); + + // Unlock + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); +} + +TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderNonblocking) { // This test will verify that an exclusive lock is denied while // someone already holds an exclsuive lock. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), @@ -221,6 +247,30 @@ TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } +TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking_NoRandomSave) { + const DisableSave ds; // Timing-related. + + // This test will verify that an exclusive lock is denied while + // someone already holds an exclsuive lock. + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), + SyscallSucceedsWithValue(0)); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); + + // Register a signal handler for SIGALRM and set an alarm that will go off + // while blocking in the subsequent flock() call. This will interrupt flock() + // and cause it to return EINTR. + struct sigaction act = {}; + act.sa_handler = trivial_handler; + ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); + ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + ASSERT_THAT(flock(fd.get(), LOCK_EX), SyscallFailsWithErrno(EINTR)); + + // Unlock + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); +} + TEST_F(FlockTest, TestMultipleHolderSharedExclusiveUpgrade) { // This test will verify that we cannot obtain an exclusive lock while // a shared lock is held by another descriptor, then verify that an upgrade diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index 5cb325a9e..a5c421118 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -1371,9 +1371,10 @@ TEST(Inotify, HardlinksReuseSameWatch) { // that now and drain the resulting events. file1_fd.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, - Are({Event(IN_CLOSE_WRITE, root_wd, Basename(file1.path())), - Event(IN_CLOSE_WRITE, file1_wd)})); + ASSERT_THAT( + events, + AreUnordered({Event(IN_CLOSE_WRITE, root_wd, Basename(file1.path())), + Event(IN_CLOSE_WRITE, file1_wd)})); // Try removing the link and let's see what events show up. Note that after // this, we still have a link to the file so the watch shouldn't be @@ -1381,8 +1382,9 @@ TEST(Inotify, HardlinksReuseSameWatch) { const std::string link1_path = link1.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ATTRIB, link1_wd), - Event(IN_DELETE, root_wd, Basename(link1_path))})); + ASSERT_THAT(events, + AreUnordered({Event(IN_ATTRIB, link1_wd), + Event(IN_DELETE, root_wd, Basename(link1_path))})); // Now remove the other link. Since this is the last link to the file, the // watch should be automatically removed. @@ -1934,14 +1936,22 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ - Event(IN_ATTRIB, file_wd), - Event(IN_DELETE, dir_wd, Basename(file.path())), - Event(IN_ACCESS, dir_wd, Basename(file.path())), - Event(IN_ACCESS, file_wd), - Event(IN_MODIFY, dir_wd, Basename(file.path())), - Event(IN_MODIFY, file_wd), - })); + EXPECT_THAT(events, AnyOf(Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }), + Are({ + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }))); fd.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); @@ -1984,7 +1994,7 @@ TEST(Inotify, ExcludeUnlink_NoRandomSave) { ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ + EXPECT_THAT(events, AreUnordered({ Event(IN_ATTRIB, file_wd), Event(IN_DELETE, dir_wd, Basename(file.path())), })); @@ -2127,12 +2137,18 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ - Event(IN_ATTRIB, file_wd), - Event(IN_DELETE, dir_wd, Basename(file.path())), - Event(IN_MODIFY, dir_wd, Basename(file.path())), - Event(IN_MODIFY, file_wd), - })); + EXPECT_THAT(events, AnyOf(Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }), + Are({ + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }))); const struct timeval times[2] = {{1, 0}, {2, 0}}; ASSERT_THAT(futimes(fd.get(), times), SyscallSucceeds()); diff --git a/test/syscalls/linux/ip6tables.cc b/test/syscalls/linux/ip6tables.cc new file mode 100644 index 000000000..97297ee2b --- /dev/null +++ b/test/syscalls/linux/ip6tables.cc @@ -0,0 +1,207 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/capability.h> +#include <sys/socket.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/iptables.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kNatTablename[] = "nat"; +constexpr char kErrorTarget[] = "ERROR"; +constexpr size_t kEmptyStandardEntrySize = + sizeof(struct ip6t_entry) + sizeof(struct xt_standard_target); +constexpr size_t kEmptyErrorEntrySize = + sizeof(struct ip6t_entry) + sizeof(struct xt_error_target); + +TEST(IP6TablesBasic, FailSockoptNonRaw) { + // Even if the user has CAP_NET_RAW, they shouldn't be able to use the + // ip6tables sockopts with a non-raw socket. + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(ENOPROTOOPT)); + + EXPECT_THAT(close(sock), SyscallSucceeds()); +} + +TEST(IP6TablesBasic, GetInfoErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info) - 1; + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IP6TablesBasic, GetEntriesErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ip6t_get_entries entries = {}; + socklen_t entries_size = sizeof(struct ip6t_get_entries) - 1; + snprintf(entries.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_ENTRIES, &entries, &entries_size), + SyscallFailsWithErrno(EINVAL)); +} + +// This tests the initial state of a machine with empty ip6tables via +// getsockopt(IP6T_SO_GET_INFO). We don't have a guarantee that the iptables are +// empty when running in native, but we can test that gVisor has the same +// initial state that a newly-booted Linux machine would have. +TEST(IP6TablesTest, InitialInfo) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)); + + // Get info via sockopt. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT( + getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // The nat table supports PREROUTING, and OUTPUT. + unsigned int valid_hooks = + (1 << NF_IP6_PRE_ROUTING) | (1 << NF_IP6_LOCAL_OUT) | + (1 << NF_IP6_POST_ROUTING) | (1 << NF_IP6_LOCAL_IN); + EXPECT_EQ(info.valid_hooks, valid_hooks); + + // Each chain consists of an empty entry with a standard target.. + EXPECT_EQ(info.hook_entry[NF_IP6_PRE_ROUTING], 0); + EXPECT_EQ(info.hook_entry[NF_IP6_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.hook_entry[NF_IP6_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.hook_entry[NF_IP6_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // The underflow points are the same as the entry points. + EXPECT_EQ(info.underflow[NF_IP6_PRE_ROUTING], 0); + EXPECT_EQ(info.underflow[NF_IP6_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.underflow[NF_IP6_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.underflow[NF_IP6_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // One entry for each chain, plus an error entry at the end. + EXPECT_EQ(info.num_entries, 5); + + EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize); + EXPECT_EQ(strcmp(info.name, kNatTablename), 0); +} + +// This tests the initial state of a machine with empty ip6tables via +// getsockopt(IP6T_SO_GET_ENTRIES). We don't have a guarantee that the iptables +// are empty when running in native, but we can test that gVisor has the same +// initial state that a newly-booted Linux machine would have. +TEST(IP6TablesTest, InitialEntries) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)); + + // Get info via sockopt. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT( + getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // Use info to get entries. + socklen_t entries_size = sizeof(struct ip6t_get_entries) + info.size; + struct ip6t_get_entries* entries = + static_cast<struct ip6t_get_entries*>(malloc(entries_size)); + snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + entries->size = info.size; + ASSERT_THAT(getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_ENTRIES, entries, + &entries_size), + SyscallSucceeds()); + + // Verify the name and size. + ASSERT_EQ(info.size, entries->size); + ASSERT_EQ(strcmp(entries->name, kNatTablename), 0); + + // Verify that the entrytable is 4 entries with accept targets and no matches + // followed by a single error target. + size_t entry_offset = 0; + while (entry_offset < entries->size) { + struct ip6t_entry* entry = reinterpret_cast<struct ip6t_entry*>( + reinterpret_cast<char*>(entries->entrytable) + entry_offset); + + // ipv6 should be zeroed. + struct ip6t_ip6 zeroed = {}; + ASSERT_EQ(memcmp(static_cast<void*>(&zeroed), + static_cast<void*>(&entry->ipv6), sizeof(zeroed)), + 0); + + // target_offset should be zero. + EXPECT_EQ(entry->target_offset, sizeof(ip6t_entry)); + + if (entry_offset < kEmptyStandardEntrySize * 4) { + // The first 4 entries are standard targets + struct xt_standard_target* target = + reinterpret_cast<struct xt_standard_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + // This is what's returned for an accept verdict. I don't know why. + EXPECT_EQ(target->verdict, -NF_ACCEPT - 1); + } else { + // The last entry is an error target + struct xt_error_target* target = + reinterpret_cast<struct xt_error_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0); + } + + entry_offset += entry->next_offset; + break; + } + + free(entries); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc index b8e4ece64..83b6a164a 100644 --- a/test/syscalls/linux/iptables.cc +++ b/test/syscalls/linux/iptables.cc @@ -67,12 +67,56 @@ TEST(IPTablesBasic, FailSockoptNonRaw) { struct ipt_getinfo info = {}; snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); socklen_t info_size = sizeof(info); - EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + EXPECT_THAT(getsockopt(sock, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), SyscallFailsWithErrno(ENOPROTOOPT)); ASSERT_THAT(close(sock), SyscallSucceeds()); } +TEST(IPTablesBasic, GetInfoErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info) - 1; + ASSERT_THAT(getsockopt(sock, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IPTablesBasic, GetEntriesErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_get_entries entries = {}; + socklen_t entries_size = sizeof(struct ipt_get_entries) - 1; + snprintf(entries.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + ASSERT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_ENTRIES, &entries, &entries_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IPTablesBasic, OriginalDstErrors) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_STREAM, 0), SyscallSucceeds()); + + // Sockets not affected by NAT should fail to find an original destination. + struct sockaddr_in addr = {}; + socklen_t addr_len = sizeof(addr); + EXPECT_THAT(getsockopt(sock, SOL_IP, SO_ORIGINAL_DST, &addr, &addr_len), + SyscallFailsWithErrno(ENOTCONN)); +} + // Fixture for iptables tests. class IPTablesTest : public ::testing::Test { protected: @@ -112,7 +156,7 @@ TEST_F(IPTablesTest, InitialState) { struct ipt_getinfo info = {}; snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); socklen_t info_size = sizeof(info); - ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + ASSERT_THAT(getsockopt(s_, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), SyscallSucceeds()); // The nat table supports PREROUTING, and OUTPUT. @@ -148,7 +192,7 @@ TEST_F(IPTablesTest, InitialState) { snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); entries->size = info.size; ASSERT_THAT( - getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size), + getsockopt(s_, SOL_IP, IPT_SO_GET_ENTRIES, entries, &entries_size), SyscallSucceeds()); // Verify the name and size. diff --git a/test/syscalls/linux/iptables.h b/test/syscalls/linux/iptables.h index 0719c60a4..d0fc10fea 100644 --- a/test/syscalls/linux/iptables.h +++ b/test/syscalls/linux/iptables.h @@ -27,27 +27,32 @@ #include <linux/netfilter/x_tables.h> #include <linux/netfilter_ipv4.h> +#include <linux/netfilter_ipv6.h> #include <net/if.h> #include <netinet/ip.h> #include <stdint.h> +// +// IPv4 ABI. +// + #define ipt_standard_target xt_standard_target #define ipt_entry_target xt_entry_target #define ipt_error_target xt_error_target enum SockOpts { // For setsockopt. - BASE_CTL = 64, - SO_SET_REPLACE = BASE_CTL, - SO_SET_ADD_COUNTERS, - SO_SET_MAX = SO_SET_ADD_COUNTERS, + IPT_BASE_CTL = 64, + IPT_SO_SET_REPLACE = IPT_BASE_CTL, + IPT_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1, + IPT_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS, // For getsockopt. - SO_GET_INFO = BASE_CTL, - SO_GET_ENTRIES, - SO_GET_REVISION_MATCH, - SO_GET_REVISION_TARGET, - SO_GET_MAX = SO_GET_REVISION_TARGET + IPT_SO_GET_INFO = IPT_BASE_CTL, + IPT_SO_GET_ENTRIES = IPT_BASE_CTL + 1, + IPT_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 2, + IPT_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 3, + IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET }; // ipt_ip specifies basic matching criteria that can be applied by examining @@ -115,7 +120,7 @@ struct ipt_entry { unsigned char elems[0]; }; -// Passed to getsockopt(SO_GET_INFO). +// Passed to getsockopt(IPT_SO_GET_INFO). struct ipt_getinfo { // The name of the table. The user only fills this in, the rest is filled in // when returning from getsockopt. Currently "nat" and "mangle" are supported. @@ -127,7 +132,7 @@ struct ipt_getinfo { unsigned int valid_hooks; // The offset into the entry table for each valid hook. The entry table is - // returned by getsockopt(SO_GET_ENTRIES). + // returned by getsockopt(IPT_SO_GET_ENTRIES). unsigned int hook_entry[NF_IP_NUMHOOKS]; // For each valid hook, the underflow is the offset into the entry table to @@ -142,14 +147,14 @@ struct ipt_getinfo { unsigned int underflow[NF_IP_NUMHOOKS]; // The number of entries in the entry table returned by - // getsockopt(SO_GET_ENTRIES). + // getsockopt(IPT_SO_GET_ENTRIES). unsigned int num_entries; - // The size of the entry table returned by getsockopt(SO_GET_ENTRIES). + // The size of the entry table returned by getsockopt(IPT_SO_GET_ENTRIES). unsigned int size; }; -// Passed to getsockopt(SO_GET_ENTRIES). +// Passed to getsockopt(IPT_SO_GET_ENTRIES). struct ipt_get_entries { // The name of the table. The user fills this in. Currently "nat" and "mangle" // are supported. @@ -195,4 +200,103 @@ struct ipt_replace { struct ipt_entry entries[0]; }; +// +// IPv6 ABI. +// + +enum SockOpts6 { + // For setsockopt. + IP6T_BASE_CTL = 64, + IP6T_SO_SET_REPLACE = IP6T_BASE_CTL, + IP6T_SO_SET_ADD_COUNTERS = IP6T_BASE_CTL + 1, + IP6T_SO_SET_MAX = IP6T_SO_SET_ADD_COUNTERS, + + // For getsockopt. + IP6T_SO_GET_INFO = IP6T_BASE_CTL, + IP6T_SO_GET_ENTRIES = IP6T_BASE_CTL + 1, + IP6T_SO_GET_REVISION_MATCH = IP6T_BASE_CTL + 4, + IP6T_SO_GET_REVISION_TARGET = IP6T_BASE_CTL + 5, + IP6T_SO_GET_MAX = IP6T_SO_GET_REVISION_TARGET +}; + +// ip6t_ip6 specifies basic matching criteria that can be applied by examining +// only the IP header of a packet. +struct ip6t_ip6 { + // Source IP address. + struct in6_addr src; + + // Destination IP address. + struct in6_addr dst; + + // Source IP address mask. + struct in6_addr smsk; + + // Destination IP address mask. + struct in6_addr dmsk; + + // Input interface. + char iniface[IFNAMSIZ]; + + // Output interface. + char outiface[IFNAMSIZ]; + + // Input interface mask. + unsigned char iniface_mask[IFNAMSIZ]; + + // Output interface mask. + unsigned char outiface_mask[IFNAMSIZ]; + + // Transport protocol. + uint16_t proto; + + // TOS. + uint8_t tos; + + // Flags. + uint8_t flags; + + // Inverse flags. + uint8_t invflags; +}; + +// ip6t_entry is an ip6tables rule. +struct ip6t_entry { + // Basic matching information used to match a packet's IP header. + struct ip6t_ip6 ipv6; + + // A caching field that isn't used by userspace. + unsigned int nfcache; + + // The number of bytes between the start of this entry and the rule's target. + uint16_t target_offset; + + // The total size of this rule, from the beginning of the entry to the end of + // the target. + uint16_t next_offset; + + // A return pointer not used by userspace. + unsigned int comefrom; + + // Counters for packets and bytes, which we don't yet implement. + struct xt_counters counters; + + // The data for all this rules matches followed by the target. This runs + // beyond the value of sizeof(struct ip6t_entry). + unsigned char elems[0]; +}; + +// Passed to getsockopt(IP6T_SO_GET_ENTRIES). +struct ip6t_get_entries { + // The name of the table. + char name[XT_TABLE_MAXNAMELEN]; + + // The size of the entry table in bytes. The user fills this in with the value + // from struct ipt_getinfo.size. + unsigned int size; + + // The entries for the given table. This will run past the size defined by + // sizeof(struct ip6t_get_entries). + struct ip6t_entry entrytable[0]; +}; + #endif // GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_ diff --git a/test/syscalls/linux/kcov.cc b/test/syscalls/linux/kcov.cc new file mode 100644 index 000000000..f3c30444e --- /dev/null +++ b/test/syscalls/linux/kcov.cc @@ -0,0 +1,70 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/errno.h> +#include <sys/ioctl.h> +#include <sys/mman.h> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// For this test to work properly, it must be run with coverage enabled. On +// native Linux, this involves compiling the kernel with kcov enabled. For +// gVisor, we need to enable the Go coverage tool, e.g. +// bazel test --collect_coverage_data --instrumentation_filter=//pkg/... <test>. +TEST(KcovTest, Kcov) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + constexpr int kSize = 4096; + constexpr int KCOV_INIT_TRACE = 0x80086301; + constexpr int KCOV_ENABLE = 0x6364; + + int fd; + ASSERT_THAT(fd = open("/sys/kernel/debug/kcov", O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + + // Kcov not enabled. + SKIP_IF(errno == ENOENT); + + ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = (uint64_t*)mmap(nullptr, kSize * sizeof(uint64_t), + PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + + for (int i = 0; i < 10; i++) { + // Make some syscalls to generate coverage data. + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallFailsWithErrno(EINVAL)); + } + + uint64_t num_pcs = *(uint64_t*)(area); + EXPECT_GT(num_pcs, 0); + for (uint64_t i = 1; i <= num_pcs; i++) { + // Verify that PCs are in the standard kernel range. + EXPECT_GT(area[i], 0xffffffff7fffffffL); + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index f8b7f7938..4a450742b 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -14,12 +14,10 @@ #include <errno.h> #include <fcntl.h> -#include <linux/magic.h> #include <linux/memfd.h> #include <linux/unistd.h> #include <string.h> #include <sys/mman.h> -#include <sys/statfs.h> #include <sys/syscall.h> #include <vector> @@ -53,6 +51,7 @@ namespace { #define F_SEAL_GROW 0x0004 #define F_SEAL_WRITE 0x0008 +using ::gvisor::testing::IsTmpfs; using ::testing::StartsWith; const std::string kMemfdName = "some-memfd"; @@ -444,20 +443,6 @@ TEST(MemfdTest, SealsAreInodeLevelProperties) { EXPECT_THAT(ftruncate(memfd3.get(), kPageSize), SyscallFailsWithErrno(EPERM)); } -PosixErrorOr<bool> IsTmpfs(const std::string& path) { - struct statfs stat; - if (statfs(path.c_str(), &stat)) { - if (errno == ENOENT) { - // Nothing at path, don't raise this as an error. Instead, just report no - // tmpfs at path. - return false; - } - return PosixError(errno, - absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); - } - return stat.f_type == TMPFS_MAGIC; -} - // Tmpfs files also support seals, but are created with F_SEAL_SEAL. TEST(MemfdTest, TmpfsFilesHaveSealSeal) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs("/tmp"))); diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index 4036a9275..27758203d 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -82,6 +82,13 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { SyscallFailsWithErrno(EACCES)); } +TEST_F(MkdirTest, MkdirAtEmptyPath) { + ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dirname_, O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(mkdirat(fd.get(), "", 0777), SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index 05dfb375a..89e4564e8 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> #include <sys/un.h> @@ -103,6 +104,24 @@ TEST(MknodTest, UnimplementedTypesReturnError) { ASSERT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM)); } +TEST(MknodTest, Socket) { + ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); + + SKIP_IF(IsRunningOnGvisor() && IsRunningWithVFS1()); + + ASSERT_THAT(mknod("./file0", S_IFSOCK | S_IRUSR | S_IWUSR, 0), + SyscallSucceeds()); + + int sk; + ASSERT_THAT(sk = socket(AF_UNIX, SOCK_SEQPACKET, 0), SyscallSucceeds()); + FileDescriptor fd(sk); + + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + absl::SNPrintF(addr.sun_path, sizeof(addr.sun_path), "./file0"); + ASSERT_THAT(connect(sk, (struct sockaddr *)&addr, sizeof(addr)), + SyscallFailsWithErrno(ECONNREFUSED)); +} + TEST(MknodTest, Fifo) { const std::string fifo = NewTempAbsPath(); ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), @@ -184,6 +203,14 @@ TEST(MknodTest, FifoTruncNoOp) { EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +TEST(MknodTest, MknodAtEmptyPath) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(mknodat(fd.get(), "", S_IFREG | 0777, 0), + SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 6d3227ab6..e52c9cbcb 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -43,6 +43,8 @@ #include "test/util/temp_path.h" #include "test/util/test_util.h" +using ::testing::AnyOf; +using ::testing::Eq; using ::testing::Gt; namespace gvisor { @@ -296,7 +298,8 @@ TEST_F(MMapTest, MapDevZeroSegfaultAfterUnmap) { }; EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); + IsPosixErrorOkAndHolds(AnyOf(Eq(W_EXITCODE(0, SIGSEGV)), + Eq(W_EXITCODE(0, 128 + SIGSEGV))))); } TEST_F(MMapTest, MapDevZeroUnaligned) { diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index 46b6f38db..3aab25b23 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -147,8 +147,15 @@ TEST(MountTest, UmountDetach) { // Unmount the tmpfs. mount.Release()(); - const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after2.st_ino); + // Only check for inode number equality if the directory is not in overlayfs. + // If xino option is not enabled and if all overlayfs layers do not belong to + // the same filesystem then "the value of st_ino for directory objects may not + // be persistent and could change even while the overlay filesystem is + // mounted." -- Documentation/filesystems/overlayfs.txt + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); + EXPECT_EQ(before.st_ino, after2.st_ino); + } // Can still read file after unmounting. std::vector<char> buf(sizeof(kContents)); @@ -213,8 +220,15 @@ TEST(MountTest, MountTmpfs) { } // Now that dir is unmounted again, we should have the old inode back. - const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after.st_ino); + // Only check for inode number equality if the directory is not in overlayfs. + // If xino option is not enabled and if all overlayfs layers do not belong to + // the same filesystem then "the value of st_ino for directory objects may not + // be persistent and could change even while the overlay filesystem is + // mounted." -- Documentation/filesystems/overlayfs.txt + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); + EXPECT_EQ(before.st_ino, after.st_ino); + } } TEST(MountTest, MountTmpfsMagicValIgnored) { diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 51eacf3f2..78c36f98f 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -88,21 +88,21 @@ TEST(CreateTest, CreateExclusively) { SyscallFailsWithErrno(EEXIST)); } -TEST(CreateTeast, CreatWithOTrunc) { +TEST(CreateTest, CreatWithOTrunc) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), SyscallFailsWithErrno(EISDIR)); } -TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) { +TEST(CreateTest, CreatDirWithOTruncAndReadOnly) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), SyscallFailsWithErrno(EISDIR)); } -TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) { +TEST(CreateTest, CreatFileWithOTruncAndReadOnly) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); int dirfd; ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), @@ -149,6 +149,116 @@ TEST(CreateTest, OpenCreateROThenRW) { EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1)); } +TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) { + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be + // cleared for the same reason. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); + + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + // Cannot restore after making permissions more restrictive. + const DisableSave ds; + ASSERT_THAT(fchmod(rfd.get(), 0200), SyscallSucceeds()); + + EXPECT_THAT(open(file.path().c_str(), O_RDONLY), + SyscallFailsWithErrno(EACCES)); + + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) { + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0200)); + + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + + // Cannot restore after making permissions more restrictive. + const DisableSave ds; + ASSERT_THAT(fchmod(wfd.get(), 0400), SyscallSucceeds()); + + EXPECT_THAT(open(file.path().c_str(), O_WRONLY), + SyscallFailsWithErrno(EACCES)); + + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) { + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be + // cleared for the same reason. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + // Create and open a file with read flag but without read permissions. + const std::string path = NewTempAbsPath(); + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_RDONLY, 0222)); + + EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + const FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_WRONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode_NoRandomSave) { + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + // Create and open a file with write flag but without write permissions. + const std::string path = NewTempAbsPath(); + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_WRONLY, 0444)); + + EXPECT_THAT(open(path.c_str(), O_WRONLY), SyscallFailsWithErrno(EACCES)); + const FileDescriptor rfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index a11a03415..b558e3a01 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -14,9 +14,7 @@ #include <arpa/inet.h> #include <linux/capability.h> -#ifndef __fuchsia__ #include <linux/filter.h> -#endif // __fuchsia__ #include <linux/if_arp.h> #include <linux/if_packet.h> #include <net/ethernet.h> @@ -618,8 +616,6 @@ TEST_P(RawPacketTest, GetSocketErrorBind) { } } -#ifndef __fuchsia__ - TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) { // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. // @@ -647,7 +643,26 @@ TEST_P(RawPacketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } -#endif // __fuchsia__ +TEST_P(RawPacketTest, SetAndGetSocketLinger) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 34291850d..c097c9187 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -13,7 +13,9 @@ // limitations under the License. #include <fcntl.h> /* Obtain O_* constant definitions */ +#include <linux/magic.h> #include <sys/ioctl.h> +#include <sys/statfs.h> #include <sys/uio.h> #include <unistd.h> @@ -198,6 +200,16 @@ TEST_P(PipeTest, NonBlocking) { SyscallFailsWithErrno(EWOULDBLOCK)); } +TEST(PipeTest, StatFS) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + struct statfs st; + EXPECT_THAT(fstatfs(fds[0], &st), SyscallSucceeds()); + EXPECT_EQ(st.f_type, PIPEFS_MAGIC); + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + TEST(Pipe2Test, CloExec) { int fds[2]; ASSERT_THAT(pipe2(fds, O_CLOEXEC), SyscallSucceeds()); diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index 04c5161f5..f675dc430 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -153,7 +153,7 @@ TEST(PrctlTest, PDeathSig) { // Enable tracing, then raise SIGSTOP and expect our parent to suppress // it. TEST_CHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) >= 0); - raise(SIGSTOP); + TEST_CHECK(raise(SIGSTOP) == 0); // Sleep until killed by our parent death signal. sleep(3) is // async-signal-safe, absl::SleepFor isn't. while (true) { diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index d6b875dbf..c1488b06b 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -16,6 +16,7 @@ #include <errno.h> #include <fcntl.h> #include <limits.h> +#include <linux/magic.h> #include <sched.h> #include <signal.h> #include <stddef.h> @@ -26,6 +27,7 @@ #include <sys/mman.h> #include <sys/prctl.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/utsname.h> #include <syscall.h> #include <unistd.h> @@ -61,6 +63,7 @@ #include "test/util/fs_util.h" #include "test/util/memory_util.h" #include "test/util/posix_error.h" +#include "test/util/proc_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -670,6 +673,23 @@ TEST(ProcSelfMaps, Mprotect) { 3 * kPageSize, PROT_READ))); } +TEST(ProcSelfMaps, SharedAnon) { + const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ, MAP_SHARED | MAP_ANONYMOUS)); + + const auto proc_self_maps = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); + for (const auto& line : absl::StrSplit(proc_self_maps, '\n')) { + const auto entry = ASSERT_NO_ERRNO_AND_VALUE(ParseProcMapsLine(line)); + if (entry.start <= m.addr() && m.addr() < entry.end) { + // cf. proc(5), "/proc/[pid]/map_files/" + EXPECT_EQ(entry.filename, "/dev/zero (deleted)"); + return; + } + } + FAIL() << "no maps entry containing mapping at " << m.ptr(); +} + TEST(ProcSelfFd, OpenFd) { int pipe_fds[2]; ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds()); @@ -692,6 +712,30 @@ TEST(ProcSelfFd, OpenFd) { ASSERT_THAT(close(pipe_fds[1]), SyscallSucceeds()); } +static void CheckFdDirGetdentsDuplicates(const std::string& path) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_RDONLY | O_DIRECTORY)); + // Open a FD whose value is supposed to be much larger than + // the number of FDs opened by current process. + auto newfd = fcntl(fd.get(), F_DUPFD, 1024); + EXPECT_GE(newfd, 1024); + auto fd_closer = Cleanup([newfd]() { close(newfd); }); + auto fd_files = ASSERT_NO_ERRNO_AND_VALUE(ListDir(path.c_str(), false)); + std::unordered_set<std::string> fd_files_dedup(fd_files.begin(), + fd_files.end()); + EXPECT_EQ(fd_files.size(), fd_files_dedup.size()); +} + +// This is a regression test for gvisor.dev/issues/3894 +TEST(ProcSelfFd, GetdentsDuplicates) { + CheckFdDirGetdentsDuplicates("/proc/self/fd"); +} + +// This is a regression test for gvisor.dev/issues/3894 +TEST(ProcSelfFdInfo, GetdentsDuplicates) { + CheckFdDirGetdentsDuplicates("/proc/self/fdinfo"); +} + TEST(ProcSelfFdInfo, CorrectFds) { // Make sure there is at least one open file. auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -771,17 +815,12 @@ TEST(ProcCpuinfo, DeniesWriteNonRoot) { constexpr int kNobody = 65534; EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); - // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in - // kernfs. - if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { - EXPECT_THAT(truncate("/proc/cpuinfo", 123), - SyscallFailsWithErrno(EACCES)); - } + EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EACCES)); }); } // With root privileges, it is possible to open /proc/cpuinfo with write mode, -// but all write operations will return EIO. +// but all write operations should fail. TEST(ProcCpuinfo, DeniesWriteRoot) { // VFS1 does not behave differently for root/non-root. SKIP_IF(IsRunningWithVFS1()); @@ -790,16 +829,10 @@ TEST(ProcCpuinfo, DeniesWriteRoot) { int fd; EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds()); if (fd > 0) { - EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO)); - EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO)); - } - // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in - // kernfs. - if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { - if (fd > 0) { - EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO)); - } - EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO)); + // Truncate is not tested--it may succeed on some kernels without doing + // anything. + EXPECT_THAT(write(fd, "x", 1), SyscallFails()); + EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFails()); } } @@ -2159,6 +2192,18 @@ TEST(Proc, PidTidIOAccounting) { noop.Join(); } +TEST(Proc, Statfs) { + struct statfs st; + EXPECT_THAT(statfs("/proc", &st), SyscallSucceeds()); + if (IsRunningWithVFS1()) { + EXPECT_EQ(st.f_type, ANON_INODE_FS_MAGIC); + } else { + EXPECT_EQ(st.f_type, PROC_SUPER_MAGIC); + } + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index b9a5a99bd..5f9378160 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -39,6 +39,7 @@ namespace testing { namespace { constexpr const char kProcNet[] = "/proc/net"; +constexpr const char kIpForward[] = "/proc/sys/net/ipv4/ip_forward"; TEST(ProcNetSymlinkTarget, FileMode) { struct stat s; @@ -517,36 +518,36 @@ TEST(ProcSysNetIpv4Recovery, CanReadAndWrite) { } TEST(ProcSysNetIpv4IpForward, Exists) { - auto fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR)); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY)); } TEST(ProcSysNetIpv4IpForward, DefaultValueEqZero) { - auto const fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR)); + auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY)); char buf = 101; EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(buf))); - EXPECT_TRUE(buf == '0') << "unexpected ip_forward: " << buf; + EXPECT_EQ(buf, '0') << "unexpected ip_forward: " << buf; } TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) { - auto const fd = - ASSERT_NO_ERRNO_AND_VALUE(Open("/proc/sys/net/ipv4/ip_forward", O_RDWR)); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); - char buf = 101; + auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDWR)); + + char buf; EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(buf))); - EXPECT_TRUE(buf == '0') << "unexpected ip_forward: " << buf; + EXPECT_TRUE(buf == '0' || buf == '1') << "unexpected ip_forward: " << buf; - constexpr char to_write = '1'; + // constexpr char to_write = '1'; + char to_write = (buf == '1') ? '0' : '1'; EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), SyscallSucceedsWithValue(sizeof(to_write))); - buf = 101; + buf = 0; EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), SyscallSucceedsWithValue(sizeof(buf))); EXPECT_EQ(buf, to_write); diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index f9392b9e0..0b174e2be 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -51,6 +51,7 @@ using ::testing::AnyOf; using ::testing::Contains; using ::testing::Eq; using ::testing::Not; +using SubprocessCallback = std::function<void()>; // Tests Unix98 pseudoterminals. // @@ -389,15 +390,15 @@ TEST(PtyTrunc, Truncate) { // (f)truncate should. FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); - int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); + int n = ASSERT_NO_ERRNO_AND_VALUE(ReplicaID(master)); std::string spath = absl::StrCat("/dev/pts/", n); - FileDescriptor slave = + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(replica.get(), 0), SyscallFailsWithErrno(EINVAL)); } TEST(BasicPtyTest, StatUnopenedMaster) { @@ -453,16 +454,16 @@ void ExpectReadable(const FileDescriptor& fd, int expected, char* buf) { EXPECT_EQ(expected, n); } -TEST(BasicPtyTest, OpenMasterSlave) { +TEST(BasicPtyTest, OpenMasterReplica) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); } -// The slave entry in /dev/pts/ disappears when the master is closed, even if -// the slave is still open. -TEST(BasicPtyTest, SlaveEntryGoneAfterMasterClose) { +// The replica entry in /dev/pts/ disappears when the master is closed, even if +// the replica is still open. +TEST(BasicPtyTest, ReplicaEntryGoneAfterMasterClose) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); // Get pty index. int index = -1; @@ -482,12 +483,12 @@ TEST(BasicPtyTest, Getdents) { FileDescriptor master1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index1 = -1; ASSERT_THAT(ioctl(master1.get(), TIOCGPTN, &index1), SyscallSucceeds()); - FileDescriptor slave1 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master1)); + FileDescriptor replica1 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master1)); FileDescriptor master2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index2 = -1; ASSERT_THAT(ioctl(master2.get(), TIOCGPTN, &index2), SyscallSucceeds()); - FileDescriptor slave2 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master2)); + FileDescriptor replica2 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master2)); // The directory contains ptmx, index1, and index2. (Plus any additional PTYs // unrelated to this test.) @@ -519,59 +520,60 @@ class PtyTest : public ::testing::Test { protected: void SetUp() override { master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_)); } void DisableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag &= ~ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } void EnableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag |= ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } - // Master and slave ends of the PTY. Non-blocking. + // Master and replica ends of the PTY. Non-blocking. FileDescriptor master_; - FileDescriptor slave_; + FileDescriptor replica_; }; -// Master to slave sanity test. -TEST_F(PtyTest, WriteMasterToSlave) { - // N.B. by default, the slave reads nothing until the master writes a newline. +// Master to replica sanity test. +TEST_F(PtyTest, WriteMasterToReplica) { + // N.B. by default, the replica reads nothing until the master writes a + // newline. constexpr char kBuf[] = "hello\n"; EXPECT_THAT(WriteFd(master_.get(), kBuf, sizeof(kBuf) - 1), SyscallSucceedsWithValue(sizeof(kBuf) - 1)); - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be + // Linux moves data from the master to the replica via async work scheduled + // via tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". char buf[sizeof(kBuf)] = {}; - ExpectReadable(slave_, sizeof(buf) - 1, buf); + ExpectReadable(replica_, sizeof(buf) - 1, buf); EXPECT_EQ(memcmp(buf, kBuf, sizeof(kBuf)), 0); } -// Slave to master sanity test. -TEST_F(PtyTest, WriteSlaveToMaster) { - // N.B. by default, the master reads nothing until the slave writes a newline, - // and the master gets a carriage return. +// Replica to master sanity test. +TEST_F(PtyTest, WriteReplicaToMaster) { + // N.B. by default, the master reads nothing until the replica writes a + // newline, and the master gets a carriage return. constexpr char kInput[] = "hello\n"; constexpr char kExpected[] = "hello\r\n"; - EXPECT_THAT(WriteFd(slave_.get(), kInput, sizeof(kInput) - 1), + EXPECT_THAT(WriteFd(replica_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be + // Linux moves data from the master to the replica via async work scheduled + // via tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". @@ -587,32 +589,33 @@ TEST_F(PtyTest, WriteInvalidUTF8) { SyscallSucceedsWithValue(sizeof(c))); } -// Both the master and slave report the standard default termios settings. +// Both the master and replica report the standard default termios settings. // -// Note that TCGETS on the master actually redirects to the slave (see comment +// Note that TCGETS on the master actually redirects to the replica (see comment // on MasterTermiosUnchangable). TEST_F(PtyTest, DefaultTermios) { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); EXPECT_THAT(ioctl(master_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); } -// Changing termios from the master actually affects the slave. +// Changing termios from the master actually affects the replica. // -// TCSETS on the master actually redirects to the slave (see comment on +// TCSETS on the master actually redirects to the replica (see comment on // MasterTermiosUnchangable). -TEST_F(PtyTest, TermiosAffectsSlave) { +TEST_F(PtyTest, TermiosAffectsReplica) { struct kernel_termios master_termios = {}; EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); master_termios.c_lflag ^= ICANON; EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); - struct kernel_termios slave_termios = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &slave_termios), SyscallSucceeds()); - EXPECT_EQ(master_termios, slave_termios); + struct kernel_termios replica_termios = {}; + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &replica_termios), + SyscallSucceeds()); + EXPECT_EQ(master_termios, replica_termios); } // The master end of the pty has termios: @@ -627,7 +630,7 @@ TEST_F(PtyTest, TermiosAffectsSlave) { // // (From drivers/tty/pty.c:unix98_pty_init) // -// All termios control ioctls on the master actually redirect to the slave +// All termios control ioctls on the master actually redirect to the replica // (drivers/tty/tty_ioctl.c:tty_mode_ioctl), making it impossible to change the // master termios. // @@ -640,7 +643,7 @@ TEST_F(PtyTest, MasterTermiosUnchangable) { EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); ExpectReadable(master_, 1, &c); EXPECT_EQ(c, '\r'); // ICRNL had no effect! @@ -653,15 +656,15 @@ TEST_F(PtyTest, TermiosICRNL) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= ICRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\n'); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ONLCR rewrites output \n to \r\n. @@ -669,42 +672,42 @@ TEST_F(PtyTest, TermiosONLCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Extra byte for NUL for EXPECT_STREQ. char buf[3] = {}; ExpectReadable(master_, 2, buf); EXPECT_STREQ(buf, "\r\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosIGNCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. - ASSERT_THAT(PollAndReadFd(slave_.get(), &c, 1, kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), &c, 1, kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); } -// Test that we can successfully poll for readable data from the slave. -TEST_F(PtyTest, TermiosPollSlave) { +// Test that we can successfully poll for readable data from the replica. +TEST_F(PtyTest, TermiosPollReplica) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); absl::Notification notify; - int sfd = slave_.get(); + int sfd = replica_.get(); ScopedThread th([sfd, ¬ify]() { notify.Notify(); @@ -753,33 +756,33 @@ TEST_F(PtyTest, TermiosPollMaster) { absl::SleepFor(absl::Seconds(1)); char s[] = "foo\n"; - ASSERT_THAT(WriteFd(slave_.get(), s, strlen(s) + 1), SyscallSucceeds()); + ASSERT_THAT(WriteFd(replica_.get(), s, strlen(s) + 1), SyscallSucceeds()); } TEST_F(PtyTest, TermiosINLCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= INLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\r'); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosONOCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONOCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), @@ -789,7 +792,7 @@ TEST_F(PtyTest, TermiosONOCR) { // out of the other end. constexpr char kInput[] = "foo\r"; constexpr int kInputSize = sizeof(kInput) - 1; - ASSERT_THAT(WriteFd(slave_.get(), kInput, kInputSize), + ASSERT_THAT(WriteFd(replica_.get(), kInput, kInputSize), SyscallSucceedsWithValue(kInputSize)); char buf[kInputSize] = {}; @@ -800,7 +803,7 @@ TEST_F(PtyTest, TermiosONOCR) { ExpectFinished(master_); // Terminal should be at column 0 again, so no CR can be read. - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), @@ -811,11 +814,11 @@ TEST_F(PtyTest, TermiosOCRNL) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= OCRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); ExpectReadable(master_, 1, &c); EXPECT_EQ(c, '\n'); @@ -831,24 +834,24 @@ TEST_F(PtyTest, VEOLTermination) { ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); // Set the EOL character to '=' and write it. constexpr char delim = '='; struct kernel_termios t = DefaultTermios(); t.c_cc[VEOL] = delim; - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now we can read, as sending EOL caused the line to become available. - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_EQ(memcmp(buf, kInput, sizeof(kInput)), 0); - ExpectReadable(slave_, 1, buf); + ExpectReadable(replica_, 1, buf); EXPECT_EQ(buf[0], '='); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write more than the 4096 character limit, then a @@ -864,9 +867,9 @@ TEST_F(PtyTest, CanonBigWrite) { // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that data written in canonical mode can be read immediately once @@ -880,15 +883,15 @@ TEST_F(PtyTest, SwitchCanonToNoncanon) { // Nothing available yet. char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); DisableCanonical(); - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { @@ -901,10 +904,10 @@ TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { @@ -917,7 +920,7 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); // Wait for the input queue to fill. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); constexpr char delim = '\n'; ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); @@ -925,12 +928,12 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); // We can also read the remaining characters. - ExpectReadable(slave_, 6, buf); + ExpectReadable(replica_, 6, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { @@ -942,15 +945,15 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput) - 1)); EnableCanonical(); // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput) - 1, buf); + ExpectReadable(replica_, sizeof(kInput) - 1, buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { @@ -964,14 +967,14 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); EnableCanonical(); // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write over the 4095 noncanonical limit, then read out @@ -990,17 +993,17 @@ TEST_F(PtyTest, NoncanonBigWrite) { } // We should be able to read out everything. Sleep a bit so that Linux has a - // chance to move data from the master to the slave. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + // chance to move data from the master to the replica. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); for (int i = 0; i < kInputSize; i++) { // This makes too many syscalls for save/restore. const DisableSave ds; char c; - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); ASSERT_EQ(c, kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1015,18 +1018,18 @@ TEST_F(PtyTest, TermiosICANONNewline) { char buf[5] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = '\n'; ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(input) + 1)); - ExpectReadable(slave_, sizeof(input) + 1, buf); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(input) + 1)); + ExpectReadable(replica_, sizeof(input) + 1, buf); EXPECT_STREQ(buf, "abc\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1041,16 +1044,16 @@ TEST_F(PtyTest, TermiosICANONEOF) { char buf[4] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = ControlCharacter('D'); ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. Note that ^D is not included. - ExpectReadable(slave_, sizeof(input), buf); + ExpectReadable(replica_, sizeof(input), buf); EXPECT_STREQ(buf, "abc"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON limits us to 4096 bytes including a terminating character. Anything @@ -1076,12 +1079,12 @@ TEST_F(PtyTest, CanonDiscard) { // There should be multiple truncated lines available to read. for (int i = 0; i < kIter; i++) { char buf[kInputSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); EXPECT_EQ(buf[kMaxLineSize - 1], delim); EXPECT_EQ(buf[kMaxLineSize - 2], kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, CanonMultiline) { @@ -1096,15 +1099,15 @@ TEST_F(PtyTest, CanonMultiline) { // Get the first line. char line1[8] = {}; - ExpectReadable(slave_, sizeof(kInput1) - 1, line1); + ExpectReadable(replica_, sizeof(kInput1) - 1, line1); EXPECT_STREQ(line1, kInput1); // Get the second line. char line2[8] = {}; - ExpectReadable(slave_, sizeof(kInput2) - 1, line2); + ExpectReadable(replica_, sizeof(kInput2) - 1, line2); EXPECT_STREQ(line2, kInput2); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { @@ -1121,15 +1124,15 @@ TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { SyscallSucceedsWithValue(sizeof(kInput2) - 1)); ASSERT_NO_ERRNO( - WaitUntilReceived(slave_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); EnableCanonical(); // Get all together as one line. char line[9] = {}; - ExpectReadable(slave_, 8, line); + ExpectReadable(replica_, 8, line); EXPECT_STREQ(line, kExpected); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchTwiceMultiline) { @@ -1146,15 +1149,15 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { // All written characters have to make it into the input queue before // canonical mode is re-enabled. If the final '!' character hasn't been // enqueued before canonical mode is re-enabled, it won't be readable. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kExpected.size())); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kExpected.size())); EnableCanonical(); // Get all together as one line. char line[10] = {}; - ExpectReadable(slave_, 9, line); + ExpectReadable(replica_, 9, line); EXPECT_STREQ(line, kExpected.c_str()); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, QueueSize) { @@ -1162,7 +1165,7 @@ TEST_F(PtyTest, QueueSize) { constexpr char kInput1[] = "GO\n"; ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); // Ensure that writing more (beyond what is readable) does not impact the // readable size. @@ -1171,7 +1174,7 @@ TEST_F(PtyTest, QueueSize) { ASSERT_THAT(WriteFd(master_.get(), input, kMaxLineSize), SyscallSucceedsWithValue(kMaxLineSize)); int inputBufSize = ASSERT_NO_ERRNO_AND_VALUE( - WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); EXPECT_EQ(inputBufSize, sizeof(kInput1) - 1); } @@ -1196,9 +1199,9 @@ TEST_F(PtyTest, PartialBadBuffer) { EXPECT_THAT(WriteFd(master_.get(), kBuf, size), SyscallSucceedsWithValue(size)); - // Read from the slave into bad_buffer. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), size)); - EXPECT_THAT(ReadFd(slave_.get(), bad_buffer, size), + // Read from the replica into bad_buffer. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), size)); + EXPECT_THAT(ReadFd(replica_.get(), bad_buffer, size), SyscallFailsWithErrno(EFAULT)); EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr; @@ -1218,16 +1221,16 @@ TEST_F(PtyTest, SimpleEcho) { TEST_F(PtyTest, GetWindowSize) { struct winsize ws; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); EXPECT_EQ(ws.ws_row, 0); EXPECT_EQ(ws.ws_col, 0); } -TEST_F(PtyTest, SetSlaveWindowSize) { +TEST_F(PtyTest, SetReplicaWindowSize) { constexpr uint16_t kRows = 343; constexpr uint16_t kCols = 2401; struct winsize ws = {.ws_row = kRows, .ws_col = kCols}; - ASSERT_THAT(ioctl(slave_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; ASSERT_THAT(ioctl(master_.get(), TIOCGWINSZ, &retrieved_ws), @@ -1243,7 +1246,7 @@ TEST_F(PtyTest, SetMasterWindowSize) { ASSERT_THAT(ioctl(master_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &retrieved_ws), + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &retrieved_ws), SyscallSucceeds()); EXPECT_EQ(retrieved_ws.ws_row, kRows); EXPECT_EQ(retrieved_ws.ws_col, kCols); @@ -1253,7 +1256,7 @@ class JobControlTest : public ::testing::Test { protected: void SetUp() override { master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_)); // Make this a session leader, which also drops the controlling terminal. // In the gVisor test environment, this test will be run as the session @@ -1263,61 +1266,82 @@ class JobControlTest : public ::testing::Test { } } - // Master and slave ends of the PTY. Non-blocking. + PosixError RunInChild(SubprocessCallback childFunc) { + pid_t child = fork(); + if (!child) { + childFunc(); + _exit(0); + } + int wstatus; + if (waitpid(child, &wstatus, 0) != child) { + return PosixError( + errno, absl::StrCat("child failed with wait status: ", wstatus)); + } + return PosixError(wstatus, "process returned"); + } + + // Master and replica ends of the PTY. Non-blocking. FileDescriptor master_; - FileDescriptor slave_; + FileDescriptor replica_; }; TEST_F(JobControlTest, SetTTYMaster) { - ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(master_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(!replica_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYNonLeader) { // Fork a process that won't be the session leader. - pid_t child = fork(); - if (!child) { - // We shouldn't be able to set the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + auto res = + RunInChild([=]() { TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 0)); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYBadArg) { - // Despite the man page saying arg should be 0 here, Linux doesn't actually - // check. - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYDifferentSession) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Fork, join a new session, and try to steal the parent's controlling - // terminal, which should fail. - pid_t child = fork(); - if (!child) { + auto res = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - // We shouldn't be able to steal the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); - _exit(0); - } + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should fail. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal. + TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int gcwstatus; + TEST_PCHECK(waitpid(grandchild, &gcwstatus, 0) == grandchild); + TEST_PCHECK(gcwstatus == 0); + }); } TEST_F(JobControlTest, ReleaseTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we // disconnect they TTY. @@ -1327,48 +1351,60 @@ TEST_F(JobControlTest, ReleaseTTY) { sigemptyset(&sa.sa_mask); struct sigaction old_sa; EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); } TEST_F(JobControlTest, ReleaseUnsetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + ASSERT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, ReleaseWrongTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(master_.get(), TIOCNOTTY) < 0 && errno == ENOTTY); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, ReleaseTTYNonLeader) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, ReleaseTTYDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - pid_t child = fork(); - if (!child) { - // Join a new session, then try to disconnect. + auto ret = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + pid_t grandchild = fork(); + if (!grandchild) { + // Join a new session, then try to disconnect. + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } // Used by the child process spawned in ReleaseTTYSignals to track received @@ -1387,7 +1423,7 @@ void sig_handler(int signum) { received |= signum; } // - Checks that thread 1 got both signals // - Checks that thread 2 didn't get any signals. TEST_F(JobControlTest, ReleaseTTYSignals) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); received = 0; struct sigaction sa = {}; @@ -1439,7 +1475,7 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { // Release the controlling terminal, sending SIGHUP and SIGCONT to all other // processes in this process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); @@ -1456,20 +1492,21 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { } TEST_F(JobControlTest, GetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - pid_t foreground_pgid; - pid_t pid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), - SyscallSucceeds()); - ASSERT_THAT(pid = getpid(), SyscallSucceeds()); - - ASSERT_EQ(foreground_pgid, pid); + auto res = RunInChild([=]() { + pid_t pid, foreground_pgid; + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + TEST_PCHECK(!ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid)); + TEST_PCHECK((pid = getpid()) >= 0); + TEST_PCHECK(pid == foreground_pgid); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // At this point there's no controlling terminal, so TIOCGPGRP should fail. pid_t foreground_pgid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + ASSERT_THAT(ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid), SyscallFailsWithErrno(ENOTTY)); } @@ -1479,113 +1516,125 @@ TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // - sets that child as the foreground process group // - kills its child and sets itself as the foreground process group. TEST_F(JobControlTest, SetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - sigemptyset(&sa.sa_mask); - sigaction(SIGTTOU, &sa, NULL); - - // Set ourself as the foreground process group. - ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); - - // Create a new process that just waits to be signaled. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!pause()); - // We should never reach this. - _exit(1); - } - - // Make the child its own process group, then make it the controlling process - // group of the terminal. - ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); + sigaction(SIGTTOU, &sa, NULL); + + // Set ourself as the foreground process group. + TEST_PCHECK(!tcsetpgrp(replica_.get(), getpgid(0))); + + // Create a new process that just waits to be signaled. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!pause()); + // We should never reach this. + _exit(1); + } - // Sanity check - we're still the controlling session. - ASSERT_EQ(getsid(0), getsid(child)); + // Make the child its own process group, then make it the controlling + // process group of the terminal. + TEST_PCHECK(!setpgid(grandchild, grandchild)); + TEST_PCHECK(!tcsetpgrp(replica_.get(), grandchild)); - // Signal the child, wait for it to exit, then retake the terminal. - ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_TRUE(WIFSIGNALED(wstatus)); - ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); + // Sanity check - we're still the controlling session. + TEST_PCHECK(getsid(0) == getsid(grandchild)); - // Set ourself as the foreground process. - pid_t pgid; - ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); + // Signal the child, wait for it to exit, then retake the terminal. + TEST_PCHECK(!kill(grandchild, SIGTERM)); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFSIGNALED(wstatus)); + TEST_PCHECK(WTERMSIG(wstatus) == SIGTERM); + + // Set ourself as the foreground process. + pid_t pgid; + TEST_PCHECK(pgid = getpgid(0) == 0); + TEST_PCHECK(!tcsetpgrp(replica_.get(), pgid)); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { pid_t pid = getpid(); - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + ASSERT_THAT(ioctl(replica_.get(), TIOCSPGRP, &pid), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t pid = -1; - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), - SyscallFailsWithErrno(EINVAL)); + pid_t pid = -1; + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &pid) && errno == EINVAL); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Create a new process, put it in a new process group, make that group the - // foreground process group, then have the process wait. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!setpgid(0, 0)); - _exit(0); - } + auto ret = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Create a new process, put it in a new process group, make that group the + // foreground process group, then have the process wait. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!setpgid(0, 0)); + _exit(0); + } - // Wait for the child to exit. - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - // The child's process group doesn't exist anymore - this should fail. - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(ESRCH)); + // Wait for the child to exit. + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + // The child's process group doesn't exist anymore - this should fail. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) != 0 && + errno == ESRCH); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - int sync_setsid[2]; - int sync_exit[2]; - ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds()); - ASSERT_THAT(pipe(sync_exit), SyscallSucceeds()); + int sync_setsid[2]; + int sync_exit[2]; + TEST_PCHECK(pipe(sync_setsid) >= 0); + TEST_PCHECK(pipe(sync_exit) >= 0); - // Create a new process and put it in a new session. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(setsid() >= 0); - // Tell the parent we're in a new session. - char c = 'c'; - TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); - TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); - _exit(0); - } + // Create a new process and put it in a new session. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // Tell the parent we're in a new session. + char c = 'c'; + TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); + TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); + _exit(0); + } - // Wait for the child to tell us it's in a new session. - char c = 'c'; - ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1)); + // Wait for the child to tell us it's in a new session. + char c = 'c'; + TEST_PCHECK(ReadFd(sync_setsid[0], &c, 1) == 1); - // Child is in a new session, so we can't make it the foregroup process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(EPERM)); + // Child is in a new session, so we can't make it the foregroup process + // group. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) && + errno == EPERM); - EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1)); + TEST_PCHECK(WriteFd(sync_exit[1], &c, 1) == 1); - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFEXITED(wstatus)); - EXPECT_EQ(WEXITSTATUS(wstatus), 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFEXITED(wstatus)); + TEST_PCHECK(!WEXITSTATUS(wstatus)); + }); + ASSERT_NO_ERRNO(ret); } // Verify that we don't hang when creating a new session from an orphaned diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc index 1d7dbefdb..4ac648729 100644 --- a/test/syscalls/linux/pty_root.cc +++ b/test/syscalls/linux/pty_root.cc @@ -50,10 +50,10 @@ TEST(JobControlRootTest, StealTTY) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); - // Make slave the controlling terminal. - ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); + // Make replica the controlling terminal. + ASSERT_THAT(ioctl(replica.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Fork, join a new session, and try to steal the parent's controlling // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg @@ -62,9 +62,9 @@ TEST(JobControlRootTest, StealTTY) { if (!child) { ASSERT_THAT(setsid(), SyscallSucceeds()); // We shouldn't be able to steal the terminal with the wrong arg value. - TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(replica.get(), TIOCSCTTY, 0)); // We should be able to steal it if we are true root. - TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1)); + TEST_PCHECK(true_root == !ioctl(replica.get(), TIOCSCTTY, 1)); _exit(0); } diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 8d6e5c913..54709371c 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -13,9 +13,7 @@ // limitations under the License. #include <linux/capability.h> -#ifndef __fuchsia__ #include <linux/filter.h> -#endif // __fuchsia__ #include <netinet/in.h> #include <netinet/ip.h> #include <netinet/ip6.h> @@ -815,8 +813,6 @@ void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf, ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len)); } -#ifndef __fuchsia__ - TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) { // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. if (IsRunningOnGvisor()) { @@ -838,8 +834,6 @@ TEST_P(RawSocketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } -#endif // __fuchsia__ - // AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to. TEST(RawSocketTest, IPv6ProtoRaw) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 3de898df7..1b9dbc584 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -416,6 +416,28 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) { ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); } +// Set and get SO_LINGER. +TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} + void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { // We're going to receive both the echo request and reply, but the order is // indeterminate. diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc index 09703b5c1..71073bb3c 100644 --- a/test/syscalls/linux/readahead.cc +++ b/test/syscalls/linux/readahead.cc @@ -16,6 +16,7 @@ #include <fcntl.h> #include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -29,7 +30,15 @@ TEST(ReadaheadTest, InvalidFD) { EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF)); } +TEST(ReadaheadTest, UnsupportedFile) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0)); + ASSERT_THAT(readahead(sock.get(), 1, 1), SyscallFailsWithErrno(EINVAL)); +} + TEST(ReadaheadTest, InvalidOffset) { + // This test is not valid for some Linux Kernels. + SKIP_IF(!IsRunningOnGvisor()); const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); @@ -79,6 +88,8 @@ TEST(ReadaheadTest, WriteOnly) { } TEST(ReadaheadTest, InvalidSize) { + // This test is not valid on some Linux kernels. + SKIP_IF(!IsRunningOnGvisor()); const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index 833c0dc4f..5458f54ad 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -170,6 +170,9 @@ TEST(RenameTest, FileOverwritesFile) { } TEST(RenameTest, DirectoryOverwritesDirectoryLinkCount) { + // Directory link counts are synthetic on overlay filesystems. + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + auto parent1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2)); diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc index 4bfb1ff56..94f9154a0 100644 --- a/test/syscalls/linux/rseq.cc +++ b/test/syscalls/linux/rseq.cc @@ -24,6 +24,7 @@ #include "test/syscalls/linux/rseq/uapi.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" +#include "test/util/posix_error.h" #include "test/util/test_util.h" namespace gvisor { @@ -31,6 +32,9 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Eq; + // Syscall test for rseq (restartable sequences). // // We must be very careful about how these tests are written. Each thread may @@ -98,7 +102,7 @@ void RunChildTest(std::string test_case, int want_status) { int status = 0; ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_EQ(status, want_status); + ASSERT_THAT(status, AnyOf(Eq(want_status), Eq(128 + want_status))); } // Test that rseq must be aligned. diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc index f036db26d..6f5d38bba 100644 --- a/test/syscalls/linux/rseq/rseq.cc +++ b/test/syscalls/linux/rseq/rseq.cc @@ -74,84 +74,95 @@ int TestUnaligned() { // Sanity test that registration works. int TestRegister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } return 0; -}; +} // Registration can't be done twice. int TestDoubleRegister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) { + ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != EBUSY) { return 1; } return 0; -}; +} // Registration can be done again after unregister. int TestRegisterUnregister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); - sys_errno(ret) != 0) { + ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } return 0; -}; +} // The pointer to rseq must match on register/unregister. int TestUnregisterDifferentPtr() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } struct rseq r2 = {}; - if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); - sys_errno(ret) != EINVAL) { + + ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); + if (sys_errno(ret) != EINVAL) { return 1; } return 0; -}; +} // The signature must match on register/unregister. int TestUnregisterDifferentSignature() { constexpr int kSignature = 0; struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kSignature); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); - sys_errno(ret) != EPERM) { + ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); + if (sys_errno(ret) != EPERM) { return 1; } return 0; -}; +} // The CPU ID is initialized. int TestCPU() { struct rseq r = {}; r.cpu_id = kRseqCPUIDUninitialized; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } @@ -163,13 +174,13 @@ int TestCPU() { } return 0; -}; +} // Critical section is eventually aborted. int TestAbort() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -185,13 +196,13 @@ int TestAbort() { rseq_loop(&r, &cs); return 0; -}; +} // Abort may be before the critical section. int TestAbortBefore() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -207,13 +218,13 @@ int TestAbortBefore() { rseq_loop(&r, &cs); return 0; -}; +} // Signature must match. int TestAbortSignature() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + if (sys_errno(ret) != 0) { return 1; } @@ -229,13 +240,13 @@ int TestAbortSignature() { rseq_loop(&r, &cs); return 1; -}; +} // Abort must not be in the critical section. int TestAbortPreCommit() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + if (sys_errno(ret) != 0) { return 1; } @@ -251,13 +262,13 @@ int TestAbortPreCommit() { rseq_loop(&r, &cs); return 1; -}; +} // rseq.rseq_cs is cleared on abort. int TestAbortClearsCS() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -277,13 +288,13 @@ int TestAbortClearsCS() { } return 0; -}; +} // rseq.rseq_cs is cleared on abort outside of critical section. int TestInvalidAbortClearsCS() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -306,7 +317,7 @@ int TestInvalidAbortClearsCS() { } return 0; -}; +} // Exit codes: // 0 - Pass diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h index 3b7bb74b1..ff0dd6e48 100644 --- a/test/syscalls/linux/rseq/test.h +++ b/test/syscalls/linux/rseq/test.h @@ -20,22 +20,20 @@ namespace testing { // Test cases supported by rseq binary. -inline constexpr char kRseqTestUnaligned[] = "unaligned"; -inline constexpr char kRseqTestRegister[] = "register"; -inline constexpr char kRseqTestDoubleRegister[] = "double-register"; -inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; -inline constexpr char kRseqTestUnregisterDifferentPtr[] = - "unregister-different-ptr"; -inline constexpr char kRseqTestUnregisterDifferentSignature[] = +constexpr char kRseqTestUnaligned[] = "unaligned"; +constexpr char kRseqTestRegister[] = "register"; +constexpr char kRseqTestDoubleRegister[] = "double-register"; +constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; +constexpr char kRseqTestUnregisterDifferentPtr[] = "unregister-different-ptr"; +constexpr char kRseqTestUnregisterDifferentSignature[] = "unregister-different-signature"; -inline constexpr char kRseqTestCPU[] = "cpu"; -inline constexpr char kRseqTestAbort[] = "abort"; -inline constexpr char kRseqTestAbortBefore[] = "abort-before"; -inline constexpr char kRseqTestAbortSignature[] = "abort-signature"; -inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; -inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; -inline constexpr char kRseqTestInvalidAbortClearsCS[] = - "invalid-abort-clears-cs"; +constexpr char kRseqTestCPU[] = "cpu"; +constexpr char kRseqTestAbort[] = "abort"; +constexpr char kRseqTestAbortBefore[] = "abort-before"; +constexpr char kRseqTestAbortSignature[] = "abort-signature"; +constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; +constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; +constexpr char kRseqTestInvalidAbortClearsCS[] = "invalid-abort-clears-cs"; } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 64123e904..a8bfb01f1 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -198,7 +198,39 @@ TEST(SendFileTest, SendAndUpdateFileOffset) { EXPECT_EQ(absl::string_view(kData, kHalfDataSize), absl::string_view(actual, bytes_sent)); - // Verify that the input file offset has been updated + // Verify that the input file offset has been updated. + ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), + SyscallSucceedsWithValue(kHalfDataSize)); + EXPECT_EQ( + absl::string_view(kData + kDataSize - bytes_sent, kDataSize - bytes_sent), + absl::string_view(actual, kHalfDataSize)); +} + +TEST(SendFileTest, SendToDevZeroAndUpdateFileOffset) { + // Create temp files. + // Test input string length must be > 2 AND even. + constexpr char kData[] = "The slings and arrows of outrageous fortune,"; + constexpr int kDataSize = sizeof(kData) - 1; + constexpr int kHalfDataSize = kDataSize / 2; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Open /dev/zero as write only. + const FileDescriptor outf = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); + + // Send data and verify that sendfile returns the correct value. + int bytes_sent; + EXPECT_THAT( + bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize), + SyscallSucceedsWithValue(kHalfDataSize)); + + char actual[kHalfDataSize]; + // Verify that the input file offset has been updated. ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), SyscallSucceedsWithValue(kHalfDataSize)); EXPECT_EQ( @@ -250,7 +282,7 @@ TEST(SendFileTest, SendAndUpdateFileOffsetFromNonzeroStartingPoint) { EXPECT_EQ(absl::string_view(kData + kQuarterDataSize, kHalfDataSize), absl::string_view(actual, bytes_sent)); - // Verify that the input file offset has been updated + // Verify that the input file offset has been updated. ASSERT_THAT(read(inf.get(), &actual, kQuarterDataSize), SyscallSucceedsWithValue(kQuarterDataSize)); @@ -501,6 +533,22 @@ TEST(SendFileTest, SendPipeWouldBlock) { SyscallFailsWithErrno(EWOULDBLOCK)); } +TEST(SendFileTest, SendPipeEOF) { + // Create and open an empty input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, 123), + SyscallSucceedsWithValue(0)); +} + TEST(SendFileTest, SendPipeBlocks) { // Create temp file. constexpr char kData[] = diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index c7fdbb924..d6e8b3e59 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -29,6 +29,8 @@ namespace testing { namespace { using ::testing::_; +using ::testing::AnyOf; +using ::testing::Eq; const uint64_t kAllocSize = kPageSize * 128ULL; @@ -394,7 +396,8 @@ TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) { }; EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); + IsPosixErrorOkAndHolds(AnyOf(Eq(W_EXITCODE(0, SIGSEGV)), + Eq(W_EXITCODE(0, 128 + SIGSEGV))))); } TEST(ShmTest, RequestingSegmentSmallerThanSHMMINFails) { diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index c20cd3fcc..e680d3dd7 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -14,6 +14,7 @@ #include <sys/socket.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -26,6 +27,9 @@ namespace gvisor { namespace testing { +// From linux/magic.h, but we can't depend on linux headers here. +#define SOCKFS_MAGIC 0x534F434B + TEST(SocketTest, UnixSocketPairProtocol) { int socks[2]; ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks), @@ -94,6 +98,19 @@ TEST(SocketTest, UnixSocketStat) { } } +TEST(SocketTest, UnixSocketStatFS) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor bound = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + + struct statfs st; + EXPECT_THAT(fstatfs(bound.get(), &st), SyscallSucceeds()); + EXPECT_EQ(st.f_type, SOCKFS_MAGIC); + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + using SocketOpenTest = ::testing::TestWithParam<int>; // UDS cannot be opened. diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index a6182f0ac..5d39e6fbd 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -463,7 +463,7 @@ TEST_P(AllSocketPairTest, SetGetSendTimeout) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); // tv_usec should be a multiple of 4000 to work on most systems. - timeval tv = {.tv_sec = 89, .tv_usec = 42000}; + timeval tv = {.tv_sec = 89, .tv_usec = 44000}; EXPECT_THAT( setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc index 19239e9e9..6cd67123d 100644 --- a/test/syscalls/linux/socket_generic_stress.cc +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -30,6 +30,9 @@ namespace testing { using ConnectStressTest = SocketPairTest; TEST_P(ConnectStressTest, Reset65kTimes) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -68,6 +71,9 @@ INSTANTIATE_TEST_SUITE_P( using PersistentListenerConnectStressTest = SocketPairTest; TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); @@ -87,6 +93,9 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { } TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); @@ -106,6 +115,9 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { } TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index c3b42682f..11fcec443 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -97,11 +97,9 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { ASSERT_THAT(socketpair(AF_INET6, 0, 0, fd), SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - // Invalid AF will return ENOAFSUPPORT. - ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); - ASSERT_THAT(socketpair(8675309, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); + // Invalid AF will fail. + ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), SyscallFails()); + ASSERT_THAT(socketpair(8675309, 0, 0, fd), SyscallFails()); } enum class Operation { @@ -116,7 +114,8 @@ std::string OperationToString(Operation operation) { return "Bind"; case Operation::Connect: return "Connect"; - case Operation::SendTo: + // Operation::SendTo is the default. + default: return "SendTo"; } } @@ -861,36 +860,38 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { SyscallSucceedsWithValue(0)); } -// This test is disabled under random save as the the restore run -// results in the stack.Seed() being different which can cause -// sequence number of final connect to be one that is considered -// old and can cause the test to be flaky. -TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - +// setupTimeWaitClose sets up a socket endpoint in TIME_WAIT state. +// Callers can choose to perform active close on either ends of the connection +// and also specify if they want to enabled SO_REUSEADDR. +void setupTimeWaitClose(const TestAddress* listener, + const TestAddress* connector, bool reuse, + bool accept_close, sockaddr_storage* listen_addr, + sockaddr_storage* conn_bound_addr) { // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener->family(), SOCK_STREAM, IPPROTO_TCP)); + if (reuse) { + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + } + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(listen_addr), + listener->addr_len), SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; + socklen_t addrlen = listener->addr_len; ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + reinterpret_cast<sockaddr*>(listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener->family(), *listen_addr)); // Connect to the listening socket. FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + Socket(connector->family(), SOCK_STREAM, IPPROTO_TCP)); // We disable saves after this point as a S/R causes the netstack seed // to be regenerated which changes what ports/ISN is picked for a given @@ -901,11 +902,12 @@ TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { // // TODO(gvisor.dev/issue/940): S/R portSeed/portHint DisableSave ds; - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + sockaddr_storage conn_addr = connector->addr; + ASSERT_NO_ERRNO(SetAddrPort(connector->family(), &conn_addr, port)); ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), + connector->addr_len), SyscallSucceeds()); // Accept the connection. @@ -913,50 +915,150 @@ TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; + socklen_t conn_addrlen = connector->addr_len; ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(conn_bound_addr), &conn_addrlen), SyscallSucceeds()); - // shutdown the accept FD to trigger TIME_WAIT on the accepted socket which - // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of - // TIME_WAIT. - ASSERT_THAT(shutdown(accepted.get(), SHUT_RDWR), SyscallSucceeds()); + FileDescriptor active_closefd, passive_closefd; + if (accept_close) { + active_closefd = std::move(accepted); + passive_closefd = std::move(conn_fd); + } else { + active_closefd = std::move(conn_fd); + passive_closefd = std::move(accepted); + } + + // shutdown to trigger TIME_WAIT. + ASSERT_THAT(shutdown(active_closefd.get(), SHUT_RDWR), SyscallSucceeds()); { const int kTimeout = 10000; struct pollfd pfd = { - .fd = conn_fd.get(), + .fd = passive_closefd.get(), .events = POLLIN, }; ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); ASSERT_EQ(pfd.revents, POLLIN); } + ScopedThread t([&]() { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = active_closefd.get(), + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + }); - conn_fd.reset(); - // This sleep is required to give conn_fd time to transition to TIME-WAIT. + passive_closefd.reset(); + t.Join(); + active_closefd.reset(); + // This sleep is needed to reduce flake to ensure that the passive-close + // ensures the state transitions to CLOSE from LAST_ACK. absl::SleepFor(absl::Seconds(1)); +} - // At this point conn_fd should be the one that moved to CLOSE_WAIT and - // eventually to CLOSED. +// These tests are disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +// +// Test re-binding of client and server bound addresses when the older +// connection is in TIME_WAIT. +TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); - // Now bind and connect a new socket and verify that we can immediately - // rebind the address bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + // Now bind a new socket and verify that we can immediately rebind the address + // bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallSucceeds()); - ASSERT_THAT(bind(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + param.listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, + TCPPassiveCloseNoTimeWaitReuseTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); + + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + param.listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Now bind and connect new socket and verify that we can immediately rebind + // the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(param.listener.family(), listen_addr)); + sockaddr_storage conn_addr = param.connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(param.connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), + param.connector.addr_len), SyscallSucceeds()); } TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { + auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -975,23 +1077,19 @@ TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), SyscallSucceeds()); - uint16_t const port = + const uint16_t port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + // Set the userTimeout on the listening socket. + constexpr int kUserTimeout = 10; + ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, + &kUserTimeout, sizeof(kUserTimeout)), + SyscallSucceeds()); + // Connect to the listening socket. FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - // We disable saves after this point as a S/R causes the netstack seed - // to be regenerated which changes what ports/ISN is picked for a given - // tuple (src ip,src port, dst ip, dst port). This can cause the final - // SYN to use a sequence number that looks like one from the current - // connection in TIME_WAIT and will not be accepted causing the test - // to timeout. - // - // TODO(gvisor.dev/issue/940): S/R portSeed/portHint - DisableSave ds; - sockaddr_storage conn_addr = connector.addr; ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), @@ -1002,51 +1100,18 @@ TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { // Accept the connection. auto accepted = ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - - // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; + // Verify that the accepted socket inherited the user timeout set on + // listening socket. + int get = -1; + socklen_t get_len = sizeof(get); ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - &conn_addrlen), + getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), SyscallSucceeds()); - - // shutdown the conn FD to trigger TIME_WAIT on the connect socket. - ASSERT_THAT(shutdown(conn_fd.get(), SHUT_RDWR), SyscallSucceeds()); - { - const int kTimeout = 10000; - struct pollfd pfd = { - .fd = accepted.get(), - .events = POLLIN, - }; - ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); - ASSERT_EQ(pfd.revents, POLLIN); - } - ScopedThread t([&]() { - constexpr int kTimeout = 10000; - constexpr int16_t want_events = POLLHUP; - struct pollfd pfd = { - .fd = conn_fd.get(), - .events = want_events, - }; - ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); - }); - - accepted.reset(); - t.Join(); - conn_fd.reset(); - - // Now bind and connect a new socket and verify that we can't immediately - // rebind the address bound by the conn_fd as it is in TIME_WAIT. - conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); - - ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - conn_addrlen), - SyscallFailsWithErrno(EADDRINUSE)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, kUserTimeout); } -TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { +TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { auto const& param = GetParam(); TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; @@ -1061,20 +1126,17 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; - ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), - SyscallSucceeds()); + { + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + &addrlen), + SyscallSucceeds()); + } const uint16_t port = ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - // Set the userTimeout on the listening socket. - constexpr int kUserTimeout = 10; - ASSERT_THAT(setsockopt(listen_fd.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, - &kUserTimeout, sizeof(kUserTimeout)), - SyscallSucceeds()); - // Connect to the listening socket. FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); @@ -1086,18 +1148,47 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { connector.addr_len), SyscallSucceeds()); - // Accept the connection. - auto accepted = - ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); - // Verify that the accepted socket inherited the user timeout set on - // listening socket. - int get = -1; - socklen_t get_len = sizeof(get); + // Trigger a RST by turning linger off and closing the socket. + struct linger opt = { + .l_onoff = 1, + .l_linger = 0, + }; ASSERT_THAT( - getsockopt(accepted.get(), IPPROTO_TCP, TCP_USER_TIMEOUT, &get, &get_len), + setsockopt(conn_fd.get(), SOL_SOCKET, SO_LINGER, &opt, sizeof(opt)), SyscallSucceeds()); - EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kUserTimeout); + ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); + + if (IsRunningOnGvisor()) { + // Gvisor packet procssing is asynchronous and can take a bit of time in + // some cases so we give it a bit of time to process the RST packet before + // calling accept. + // + // There is nothing to poll() on so we have no choice but to use a sleep + // here. + absl::SleepFor(absl::Milliseconds(100)); + } + + sockaddr_storage accept_addr; + socklen_t addrlen = sizeof(accept_addr); + + auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( + listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen)); + ASSERT_EQ(addrlen, listener.addr_len); + + // TODO(gvisor.dev/issue/3812): Remove after SO_ERROR is fixed. + if (IsRunningOnGvisor()) { + char buf[10]; + ASSERT_THAT(ReadFd(accept_fd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(ECONNRESET)); + } else { + int err; + socklen_t optlen = sizeof(err); + ASSERT_THAT( + getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + SyscallSucceeds()); + ASSERT_EQ(err, ECONNRESET); + ASSERT_EQ(optlen, sizeof(err)); + } } // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not @@ -2573,6 +2664,44 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { SyscallSucceeds()); } +TEST_P(SocketMultiProtocolInetLoopbackTest, + MultipleBindsAllowedNoListeningReuseAddr) { + const auto& param = GetParam(); + // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP + // this is only permitted if there is no other listening socket. + SKIP_IF(param.type != SOCK_STREAM); + // Bind the v4 loopback on a v4 socket. + const TestAddress& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Now create a socket and bind it to the same port, this should + // succeed since there is no listening socket for the same port. + FileDescriptor second_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(second_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(second_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); +} + TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { auto const& param = GetParam(); TestAddress const& test_addr = V4Loopback(); diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc index 791e2bd51..1a0b53394 100644 --- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -168,6 +168,71 @@ INSTANTIATE_TEST_SUITE_P( TestParam{V6Loopback(), V6Loopback()}), DescribeTestParam); +struct ProtocolTestParam { + std::string description; + int type; +}; + +std::string DescribeProtocolTestParam( + ::testing::TestParamInfo<ProtocolTestParam> const& info) { + return info.param.description; +} + +using SocketMultiProtocolInetLoopbackTest = + ::testing::TestWithParam<ProtocolTestParam>; + +TEST_P(SocketMultiProtocolInetLoopbackTest, + BindAvoidsListeningPortsReuseAddr_NoRandomSave) { + const auto& param = GetParam(); + // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP + // this is only permitted if there is no other listening socket. + SKIP_IF(param.type != SOCK_STREAM); + + DisableSave ds; // Too many syscalls. + + // A map of port to file descriptor binding the port. + std::map<uint16_t, FileDescriptor> listen_sockets; + + // Exhaust all ephemeral ports. + while (true) { + // Bind the v4 loopback on a v4 socket. + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int ret = bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len); + if (ret != 0) { + ASSERT_EQ(errno, EADDRINUSE); + break; + } + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + uint16_t port = reinterpret_cast<sockaddr_in*>(&bound_addr)->sin_port; + + // Newly bound port should not already be in use by a listening socket. + ASSERT_EQ(listen_sockets.find(port), listen_sockets.end()); + auto fd = bound_fd.get(); + listen_sockets.insert(std::make_pair(port, std::move(bound_fd))); + ASSERT_THAT(listen(fd, SOMAXCONN), SyscallSucceeds()); + } +} + +INSTANTIATE_TEST_SUITE_P( + AllFamilies, SocketMultiProtocolInetLoopbackTest, + ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, + ProtocolTestParam{"UDP", SOCK_DGRAM}), + DescribeProtocolTestParam); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 53c076787..f4b69c46c 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -819,18 +819,37 @@ TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { EXPECT_EQ(get, kDefaultTCPLingerTimeout); } -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutLessThanZero) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - constexpr int kZero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, - sizeof(kZero)), - SyscallSucceedsWithValue(0)); - constexpr int kNegative = -1234; EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kNegative, sizeof(kNegative)), SyscallSucceedsWithValue(0)); + int get = INT_MAX; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, -1); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_THAT(get, + AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout))); } TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveMax) { @@ -1061,5 +1080,124 @@ TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { } } +// Test setsockopt and getsockopt for a socket with SO_LINGER option. +TEST_P(TCPSocketPairTest, SetAndGetLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Check getsockopt before SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_THAT(got_len, sizeof(got_linger)); + struct linger want_linger = {}; + EXPECT_EQ(0, memcmp(&want_linger, &got_linger, got_len)); + + // Set and get SO_LINGER with negative values. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = -3; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(sl.l_onoff, got_linger.l_onoff); + // Linux returns a different value as it uses HZ to convert the seconds to + // jiffies which overflows for negative values. We want to be compatible with + // linux for getsockopt return value. + if (IsRunningOnGvisor()) { + EXPECT_EQ(sl.l_linger, got_linger.l_linger); + } + + // Set and get SO_LINGER option with positive values. + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test socket to disable SO_LINGER option. +TEST_P(TCPSocketPairTest, SetOffLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + sl.l_onoff = 0; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set to zero. + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test close on dup'd socket with SO_LINGER option set. +TEST_P(TCPSocketPairTest, CloseWithLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + FileDescriptor dupFd = FileDescriptor(dup(sockets->first_fd())); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + char buf[10] = {}; + // Write on dupFd should succeed as socket will not be closed until + // all references are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); + + // Close the socket. + dupFd.reset(); + // Write on dupFd should fail as all references for socket are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index edb86aded..3f2c0fdf2 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -435,8 +435,10 @@ TEST_P(UDPSocketPairTest, TOSRecvMismatch) { // Test that an IPv4 socket does not support the IPv6 TClass option. TEST_P(UDPSocketPairTest, TClassRecvMismatch) { - // This should only test AF_INET sockets for the mismatch behavior. - SKIP_IF(GetParam().domain != AF_INET); + // This should only test AF_INET6 sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET6); + // IPV6_RECVTCLASS is only valid for SOCK_DGRAM and SOCK_RAW. + SKIP_IF(GetParam().type != SOCK_DGRAM | GetParam().type != SOCK_RAW); auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -448,5 +450,27 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) { SyscallFailsWithErrno(EOPNOTSUPP)); } +// Test the SO_LINGER option can be set/get on udp socket. +TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT( + getsockopt(sockets->first_fd(), level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index 1c7b0cf90..8f7ccc868 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -217,6 +217,8 @@ TEST_P(IPUnboundSocketTest, InvalidLargeTOS) { } TEST_P(IPUnboundSocketTest, CheckSkipECN) { + // Test is inconsistant on different kernels. + SKIP_IF(!IsRunningOnGvisor()); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); int set = 0xFF; socklen_t set_sz = sizeof(set); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index bc005e2bb..02ea05e22 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -2121,39 +2121,6 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { SyscallSucceedsWithValue(kMessageSize)); } -// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports. -// We disable S/R because this test creates a large number of sockets. -TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) { - auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - constexpr int kClients = 65536; - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(receiver1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Disable cooperative S/R as we are making too many syscalls. - DisableSave ds; - std::vector<std::unique_ptr<FileDescriptor>> sockets; - for (int i = 0; i < kClients; i++) { - auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len); - if (ret == 0) { - sockets.push_back(std::move(s)); - continue; - } - ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); - break; - } -} - // Test that socket will receive packet info control message. TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) { // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc new file mode 100644 index 000000000..8052bf404 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <vector> + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketNetlinkTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc new file mode 100644 index 000000000..bcbd2feac --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/socket.h> +#include <sys/types.h> + +#include "gtest/gtest.h" +#include "absl/memory/memory.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketNogotsanTest = SimpleSocketTest; + +// Check that connect returns EAGAIN when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketNogotsanTest, + UDPConnectPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(receiver1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector<std::unique_ptr<FileDescriptor>> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); + break; + } +} + +// Check that bind returns EADDRINUSE when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + auto addr = V4Loopback(); + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector<std::unique_ptr<FileDescriptor>> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = + bind(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketNogotsanTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc new file mode 100644 index 000000000..49a0f06d9 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -0,0 +1,225 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h" + +#include <arpa/inet.h> +#include <poll.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/util/capability_util.h" +#include "test/util/cleanup.h" + +namespace gvisor { +namespace testing { + +constexpr size_t kSendBufSize = 200; + +// Checks that the loopback interface considers itself bound to all IPs in an +// associated subnet. +TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr addr; + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); + + auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcv_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Send from an unassigned address but an address that is in the subnet + // associated with the loopback interface. + TestAddress sender_addr("V4NotAssignd1"); + sender_addr.addr.ss_family = AF_INET; + sender_addr.addr_len = sizeof(sockaddr_in); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.2", + &(reinterpret_cast<sockaddr_in*>(&sender_addr.addr) + ->sin_addr.s_addr))); + ASSERT_THAT( + bind(snd_sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + + // Send the packet to an unassigned address but an address that is in the + // subnet associated with the loopback interface. + TestAddress receiver_addr("V4NotAssigned2"); + receiver_addr.addr.ss_family = AF_INET; + receiver_addr.addr_len = sizeof(sockaddr_in); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.254", + &(reinterpret_cast<sockaddr_in*>(&receiver_addr.addr) + ->sin_addr.s_addr))); + ASSERT_THAT( + bind(rcv_sock->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(rcv_sock->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len); + char send_buf[kSendBufSize]; + RandomizeBuffer(send_buf, kSendBufSize); + ASSERT_THAT( + RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0, + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceedsWithValue(kSendBufSize)); + + // Check that we received the packet. + char recv_buf[kSendBufSize] = {}; + ASSERT_THAT(RetryEINTR(recv)(rcv_sock->get(), recv_buf, kSendBufSize, 0), + SyscallSucceedsWithValue(kSendBufSize)); + ASSERT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)); +} + +// Tests that broadcast packets are delivered to all interested sockets +// (wildcard and broadcast address specified sockets). +// +// Note, we cannot test the IPv4 Broadcast (255.255.255.255) because we do +// not have a route to it. +TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { + constexpr uint16_t kPort = 9876; + // Wait up to 20 seconds for the data. + constexpr int kPollTimeoutMs = 20000; + // Number of sockets per socket type. + constexpr int kNumSocketsPerType = 2; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr addr; + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + 24 /* prefixlen */, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); + + TestAddress broadcast_address("SubnetBroadcastAddress"); + broadcast_address.addr.ss_family = AF_INET; + broadcast_address.addr_len = sizeof(sockaddr_in); + auto broadcast_address_in = + reinterpret_cast<sockaddr_in*>(&broadcast_address.addr); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.255", + &broadcast_address_in->sin_addr.s_addr)); + broadcast_address_in->sin_port = htons(kPort); + + TestAddress any_address = V4Any(); + reinterpret_cast<sockaddr_in*>(&any_address.addr)->sin_port = htons(kPort); + + // We create sockets bound to both the wildcard address and the broadcast + // address to make sure both of these types of "broadcast interested" sockets + // receive broadcast packets. + std::vector<std::unique_ptr<FileDescriptor>> socks; + for (bool bind_wildcard : {false, true}) { + // Create multiple sockets for each type of "broadcast interested" + // socket so we can test that all sockets receive the broadcast packet. + for (int i = 0; i < kNumSocketsPerType; i++) { + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto idx = socks.size(); + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + if (bind_wildcard) { + ASSERT_THAT( + bind(sock->get(), reinterpret_cast<sockaddr*>(&any_address.addr), + any_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } else { + ASSERT_THAT(bind(sock->get(), + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } + + socks.push_back(std::move(sock)); + } + } + + char send_buf[kSendBufSize]; + RandomizeBuffer(send_buf, kSendBufSize); + + // Broadcasts from each socket should be received by every socket (including + // the sending socket). + for (int w = 0; w < socks.size(); w++) { + auto& w_sock = socks[w]; + ASSERT_THAT( + RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0, + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "]"; + + // Check that we received the packet on all sockets. + for (int r = 0; r < socks.size(); r++) { + auto& r_sock = socks[r]; + + struct pollfd poll_fd = {r_sock->get(), POLLIN, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)) + << "write socks[" << w << "] & read socks[" << r << "]"; + + char recv_buf[kSendBufSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(r_sock->get(), recv_buf, kSendBufSize, 0), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + EXPECT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + } + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h new file mode 100644 index 000000000..73e7836d5 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketNetlinkTest = SimpleSocketTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc new file mode 100644 index 000000000..17021ff82 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <vector> + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P( + IPv6UDPSockets, IPv6UDPUnboundSocketNetlinkTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv6UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc new file mode 100644 index 000000000..2ee218231 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h" + +#include <arpa/inet.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/util/capability_util.h" + +namespace gvisor { +namespace testing { + +// Checks that the loopback interface does not consider itself bound to all IPs +// in an associated subnet. +TEST_P(IPv6UDPUnboundSocketNetlinkTest, JoinSubnet) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in6_addr addr; + EXPECT_EQ(1, inet_pton(AF_INET6, "2001:db8::1", &addr)); + EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET6, + /*prefixlen=*/64, &addr, sizeof(addr))); + + // Binding to an unassigned address but an address that is in the subnet + // associated with the loopback interface should fail. + TestAddress sender_addr("V6NotAssignd1"); + sender_addr.addr.ss_family = AF_INET6; + sender_addr.addr_len = sizeof(sockaddr_in6); + EXPECT_EQ(1, inet_pton(AF_INET6, "2001:db8::2", + reinterpret_cast<sockaddr_in6*>(&sender_addr.addr) + ->sin6_addr.s6_addr)); + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + EXPECT_THAT(bind(sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h new file mode 100644 index 000000000..88098be82 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv6 UDP sockets. +using IPv6UDPUnboundSocketNetlinkTest = SimpleSocketTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index bde1dbb4d..a354f3f80 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -26,6 +26,62 @@ namespace { constexpr uint32_t kSeq = 12345; +// Types of address modifications that may be performed on an interface. +enum class LinkAddrModification { + kAdd, + kDelete, +}; + +// Populates |hdr| with appripriate values for the modification type. +PosixError PopulateNlmsghdr(LinkAddrModification modification, + struct nlmsghdr* hdr) { + switch (modification) { + case LinkAddrModification::kAdd: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + return NoError(); + case LinkAddrModification::kDelete: + hdr->nlmsg_type = RTM_DELADDR; + hdr->nlmsg_flags = NLM_F_REQUEST; + return NoError(); + } + + return PosixError(EINVAL); +} + +// Adds or removes the specified address from the specified interface. +PosixError LinkModifyLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen, + LinkAddrModification modification) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + PosixError err = PopulateNlmsghdr(modification, &req.hdr); + if (!err.ok()) { + return err; + } + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + } // namespace PosixError DumpLinks( @@ -84,31 +140,14 @@ PosixErrorOr<Link> LoopbackLink() { PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifaddr; - char attrbuf[512]; - }; - - struct request req = {}; - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifaddr.ifa_index = index; - req.ifaddr.ifa_family = family; - req.ifaddr.ifa_prefixlen = prefixlen; - - struct rtattr* rta = reinterpret_cast<struct rtattr*>( - reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); - rta->rta_type = IFA_LOCAL; - rta->rta_len = RTA_LENGTH(addrlen); - req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); - memcpy(RTA_DATA(rta), addr, addrlen); + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kAdd); +} - return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kDelete); } PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index 149c4a7f6..e5badca70 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -43,6 +43,10 @@ PosixErrorOr<Link> LoopbackLink(); PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); +// LinkDelLocalAddr removes IFA_LOCAL attribute on the interface. +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + // LinkChangeFlags changes interface flags. E.g. IFF_UP. PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 99e77b89e..1edcb15a7 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -103,6 +103,24 @@ TEST_P(StreamUnixSocketPairTest, Sendto) { SyscallFailsWithErrno(EISCONN)); } +TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct linger sl = {1, 5}; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&got_linger, &sl, length)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index 08fc4b1b7..a1d2b9b11 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -298,6 +298,23 @@ TEST(SpliceTest, ToPipe) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, ToPipeEOF) { + // Create and open an empty input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor in_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Splice from the empty file to the pipe. + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, 123, 0), + SyscallSucceedsWithValue(0)); +} + TEST(SpliceTest, ToPipeOffset) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -342,7 +359,7 @@ TEST(SpliceTest, FromPipe) { ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); - // Open the input file. + // Open the output file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); @@ -364,6 +381,40 @@ TEST(SpliceTest, FromPipe) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, FromPipeMultiple) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + std::string buf = "abcABC123"; + ASSERT_THAT(write(wfd.get(), buf.c_str(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + + // Open the output file. + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor out_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); + + // Splice from the pipe to the output file over several calls. + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + + // Reset cursor to zero so that we can check the contents. + ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + + // Contents should be equal. + std::vector<char> rbuf(buf.size()); + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), buf.c_str(), buf.size()), 0); +} + TEST(SpliceTest, FromPipeOffset) { // Create a new pipe. int fds[2]; @@ -693,6 +744,34 @@ TEST(SpliceTest, FromPipeMaxFileSize) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, FromPipeToDevZero) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + const FileDescriptor zero = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); + + // Close the write end to prevent blocking below. + wfd.reset(); + + // Splice to /dev/zero. The first call should empty the pipe, and the return + // value should not exceed the number of bytes available for reading. + EXPECT_THAT( + splice(rfd.get(), nullptr, zero.get(), nullptr, kPageSize + 123, 0), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_THAT(splice(rfd.get(), nullptr, zero.get(), nullptr, 1, 0), + SyscallSucceedsWithValue(0)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 2503960f3..92260b1e1 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -97,6 +97,11 @@ TEST_F(StatTest, FstatatSymlink) { } TEST_F(StatTest, Nlinks) { + // Skip this test if we are testing overlayfs because overlayfs does not + // (intentionally) return the correct nlink value for directories. + // See fs/overlayfs/inode.c:ovl_getattr(). + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Directory is initially empty, it should contain 2 links (one from itself, @@ -328,20 +333,23 @@ TEST_F(StatTest, LeadingDoubleSlash) { // Test that a rename doesn't change the underlying file. TEST_F(StatTest, StatDoesntChangeAfterRename) { - const TempPath old_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath old_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const TempPath new_path(NewTempAbsPath()); struct stat st_old = {}; struct stat st_new = {}; - ASSERT_THAT(stat(old_dir.path().c_str(), &st_old), SyscallSucceeds()); - ASSERT_THAT(rename(old_dir.path().c_str(), new_path.path().c_str()), + ASSERT_THAT(stat(old_file.path().c_str(), &st_old), SyscallSucceeds()); + ASSERT_THAT(rename(old_file.path().c_str(), new_path.path().c_str()), SyscallSucceeds()); ASSERT_THAT(stat(new_path.path().c_str(), &st_new), SyscallSucceeds()); EXPECT_EQ(st_old.st_nlink, st_new.st_nlink); EXPECT_EQ(st_old.st_dev, st_new.st_dev); - EXPECT_EQ(st_old.st_ino, st_new.st_ino); + // Overlay filesystems may synthesize directory inode numbers on the fly. + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) { + EXPECT_EQ(st_old.st_ino, st_new.st_ino); + } EXPECT_EQ(st_old.st_mode, st_new.st_mode); EXPECT_EQ(st_old.st_uid, st_new.st_uid); EXPECT_EQ(st_old.st_gid, st_new.st_gid); @@ -378,7 +386,9 @@ TEST_F(StatTest, LinkCountsWithRegularFileChild) { // This test verifies that inodes remain around when there is an open fd // after link count hits 0. -TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { +// +// It is marked NoSave because we don't support saving unlinked files. +TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoSave) { // Setting the enviornment variable GVISOR_GOFER_UNCACHED to any value // will prevent this test from running, see the tmpfs lifecycle. // @@ -387,9 +397,6 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { const char* uncached_gofer = getenv("GVISOR_GOFER_UNCACHED"); SKIP_IF(uncached_gofer != nullptr); - // We don't support saving unlinked files. - const DisableSave ds; - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( dir.path(), "hello", TempPath::kDefaultFileMode)); @@ -432,6 +439,11 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { // Test link counts with a directory as the child. TEST_F(StatTest, LinkCountsWithDirChild) { + // Skip this test if we are testing overlayfs because overlayfs does not + // (intentionally) return the correct nlink value for directories. + // See fs/overlayfs/inode.c:ovl_getattr(). + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Before a child is added the two links are "." and the link from the parent. diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc index aca51d30f..f0fb166bd 100644 --- a/test/syscalls/linux/statfs.cc +++ b/test/syscalls/linux/statfs.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/magic.h> #include <sys/statfs.h> #include <unistd.h> @@ -43,14 +44,10 @@ TEST(StatfsTest, InternalTmpfs) { TEST(StatfsTest, InternalDevShm) { struct statfs st; EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); -} - -TEST(StatfsTest, NameLen) { - struct statfs st; - EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); // This assumes that /dev/shm is tmpfs. - EXPECT_EQ(st.f_namelen, NAME_MAX); + // Note: We could be an overlay on some configurations. + EXPECT_TRUE(st.f_type == TMPFS_MAGIC || st.f_type == OVERLAYFS_SUPER_MAGIC); } TEST(FstatfsTest, CannotStatBadFd) { diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index a17ff62e9..4d9eba7f0 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -218,6 +218,36 @@ TEST(SymlinkTest, PreadFromSymlink) { EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); } +TEST(SymlinkTest, PwriteToSymlink) { + std::string name = NewTempAbsPath(); + int fd; + ASSERT_THAT(fd = open(name.c_str(), O_CREAT, 0644), SyscallSucceeds()); + ASSERT_THAT(close(fd), SyscallSucceeds()); + + std::string linkname = NewTempAbsPath(); + ASSERT_THAT(symlink(name.c_str(), linkname.c_str()), SyscallSucceeds()); + + ASSERT_THAT(fd = open(linkname.c_str(), O_WRONLY), SyscallSucceeds()); + + const int data_size = 10; + const std::string data = std::string(data_size, 'a'); + EXPECT_THAT(pwrite64(fd, data.c_str(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); + + ASSERT_THAT(close(fd), SyscallSucceeds()); + ASSERT_THAT(fd = open(name.c_str(), O_RDONLY), SyscallSucceeds()); + + char buf[data_size + 1]; + EXPECT_THAT(pread64(fd, buf, data.size(), 0), SyscallSucceeds()); + buf[data.size()] = '\0'; + EXPECT_STREQ(buf, data.c_str()); + + ASSERT_THAT(close(fd), SyscallSucceeds()); + + EXPECT_THAT(unlink(name.c_str()), SyscallSucceeds()); + EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); +} + TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) { // Drop capabilities that allow us to override file and directory permissions. ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); @@ -297,6 +327,16 @@ TEST(SymlinkTest, FollowUpdatesATime) { EXPECT_LT(st_before_follow.st_atime, st_after_follow.st_atime); } +TEST(SymlinkTest, SymlinkAtEmptyPath) { + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(symlinkat(file.path().c_str(), fd.get(), ""), + SyscallFailsWithErrno(ENOENT)); +} + class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {}; // Test that creating an existing symlink with creat will create the target. diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index a6325a761..ab731db1d 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -13,9 +13,9 @@ // limitations under the License. #include <fcntl.h> -#ifndef __fuchsia__ +#ifdef __linux__ #include <linux/filter.h> -#endif // __fuchsia__ +#endif // __linux__ #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -1586,7 +1586,7 @@ TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) { } } -#ifndef __fuchsia__ +#ifdef __linux__ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. // gVisor currently silently ignores attaching a filter. @@ -1620,6 +1620,8 @@ TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) { SyscallSucceeds()); } +#endif // __linux__ + TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) { // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. SKIP_IF(IsRunningOnGvisor()); @@ -1641,8 +1643,6 @@ TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } -#endif // __fuchsia__ - INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 7a8ac30a4..97db2b321 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -12,13 +12,1839 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/udp_socket_test_cases.h" +#include <arpa/inet.h> +#include <fcntl.h> +#ifdef __linux__ +#include <linux/errqueue.h> +#include <linux/filter.h> +#endif // __linux__ +#include <netinet/in.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "absl/strings/str_format.h" +#ifndef SIOCGSTAMP +#include <linux/sockios.h> +#endif + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { namespace { +// Fixture for tests parameterized by the address family to use (AF_INET and +// AF_INET6) when creating sockets. +class UdpSocketTest + : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { + protected: + // Creates two sockets that will be used by test cases. + void SetUp() override; + + // Binds the socket bind_ to the loopback and updates bind_addr_. + PosixError BindLoopback(); + + // Binds the socket bind_ to Any and updates bind_addr_. + PosixError BindAny(); + + // Binds given socket to address addr and updates. + PosixError BindSocket(int socket, struct sockaddr* addr); + + // Return initialized Any address to port 0. + struct sockaddr_storage InetAnyAddr(); + + // Return initialized Loopback address to port 0. + struct sockaddr_storage InetLoopbackAddr(); + + // Disconnects socket sockfd. + void Disconnect(int sockfd); + + // Get family for the test. + int GetFamily(); + + // Socket used by Bind methods + FileDescriptor bind_; + + // Second socket used for tests. + FileDescriptor sock_; + + // Address for bind_ socket. + struct sockaddr* bind_addr_; + + // Initialized to the length based on GetFamily(). + socklen_t addrlen_; + + // Storage for bind_addr_. + struct sockaddr_storage bind_addr_storage_; + + private: + // Helper to initialize addrlen_ for the test case. + socklen_t GetAddrLength(); +}; + +// Gets a pointer to the port component of the given address. +uint16_t* Port(struct sockaddr_storage* addr) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + return &sin->sin_port; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + return &sin6->sin6_port; + } + } + + return nullptr; +} + +// Sets addr port to "port". +void SetPort(struct sockaddr_storage* addr, uint16_t port) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + sin->sin_port = port; + break; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + sin6->sin6_port = port; + break; + } + } +} + +void UdpSocketTest::SetUp() { + addrlen_ = GetAddrLength(); + + bind_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); + bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + + sock_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); +} + +int UdpSocketTest::GetFamily() { + if (GetParam() == AddressFamily::kIpv4) { + return AF_INET; + } + return AF_INET6; +} + +PosixError UdpSocketTest::BindLoopback() { + bind_addr_storage_ = InetLoopbackAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindAny() { + bind_addr_storage_ = InetAnyAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { + socklen_t len = sizeof(bind_addr_storage_); + + // Bind, then check that we get the right address. + RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); + + RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); + + if (addrlen_ != len) { + return PosixError( + EINVAL, + absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); + } + return PosixError(0); +} + +socklen_t UdpSocketTest::GetAddrLength() { + struct sockaddr_storage addr; + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + return sizeof(*sin); + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + return sizeof(*sin6); +} + +sockaddr_storage UdpSocketTest::InetAnyAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + sin->sin_port = htons(0); + return addr; + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = IN6ADDR_ANY_INIT; + sin6->sin6_port = htons(0); + return addr; +} + +sockaddr_storage UdpSocketTest::InetLoopbackAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + sin->sin_port = htons(0); + return addr; + } + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = in6addr_loopback; + sin6->sin6_port = htons(0); + return addr; +} + +void UdpSocketTest::Disconnect(int sockfd) { + sockaddr_storage addr_storage = InetAnyAddr(); + sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + socklen_t addrlen = sizeof(addr_storage); + + addr->sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, Creation) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); +} + +TEST_P(UdpSocketTest, Getsockname) { + // Check that we're not bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), + 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Getpeername) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that we're not connected. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Connect, then check that we get the right address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, SendNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Do send & write, they must fail. + char buf[512]; + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(EDESTADDRREQ)); + + EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EDESTADDRREQ)); + + // Use sendto. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ConnectBinds) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ReceiveNotBound) { + char buf[512]; + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, Bind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EINVAL)); + + // Check that we're still bound to the original address. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, BindInUse) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(UdpSocketTest, ReceiveAfterConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send from sock_ to bind_ + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + for (int i = 0; i < 2; i++) { + // Connet sock_ to bound address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + // Send from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Disconnect sock_. + struct sockaddr unspec = {}; + unspec.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), + SyscallSucceeds()); + } +} + +TEST_P(UdpSocketTest, Connect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to bind after connect. + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallFailsWithErrno(EINVAL)); + + struct sockaddr_storage bind2_storage = InetLoopbackAddr(); + struct sockaddr* bind2_addr = + reinterpret_cast<struct sockaddr*>(&bind2_storage); + FileDescriptor bind2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); + + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); + + // Check that peer name changed. + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); +} + +TEST_P(UdpSocketTest, ConnectAnyZero) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind to the next port above bind_. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage unspec = {}; + unspec.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), + sizeof(unspec.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + socklen_t addrlen = sizeof(unspec); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { + ASSERT_NO_ERRNO(BindAny()); + + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + socklen_t addrlen = sizeof(addr); + + // Connect the socket. + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); + + // If the socket is bound to ANY and connected to a loopback address, + // getsockname() has to return the loopback address. + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + Disconnect(sock_.get()); + + // Check that we're still bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, any, addrlen), 0); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, Disconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + for (int i = 0; i < 2; i++) { + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to disconnect. + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), + sizeof(addr.ss_family)), + SyscallSucceeds()); + + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&addr), *Port(&any_storage)); + } +} + +TEST_P(UdpSocketTest, ConnectBadAddress) { + struct sockaddr addr = {}; + addr.sa_family = GetFamily(); + ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage addr_storage = InetAnyAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send to a different destination than we're connected to. + char buf[512]; + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + // Connect to loopback:bind_addr_+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Set sock to non-blocking. + int opts = 0; + ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallFailsWithErrno(EAGAIN)); +} + +TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, SendAndReceiveConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveFromNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that the data isn't received because it was sent from a different + // address than we're connected. + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveBeforeConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Receive the data. It works because it was sent before the connect. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Send again. This time it should not be received. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveFrom) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data and sender address. + char received[sizeof(buf)]; + struct sockaddr_storage addr2; + socklen_t addr2len = sizeof(addr2); + EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, + reinterpret_cast<sockaddr*>(&addr2), &addr2len), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + EXPECT_EQ(addr2len, addrlen_); + EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Listen) { + ASSERT_THAT(listen(sock_.get(), SOMAXCONN), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +TEST_P(UdpSocketTest, Accept) { + ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +// This test validates that a read shutdown with pending data allows the read +// to proceed with the data before returning EAGAIN. +TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Verify that we get EWOULDBLOCK when there is nothing to read. + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + const char* buf = "abc"; + EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); + + int opts = 0; + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_NE(opts & O_NONBLOCK, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // We should get the data even though read has been shutdown. + EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2)); + + // Because we read less than the entire packet length, since it's a packet + // based socket any subsequent reads should return EWOULDBLOCK. + EXPECT_THAT(recv(bind_.get(), received, 1, 0), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +// This test is validating that even after a socket is shutdown if it's +// reconnected it will reset the shutdown state. +TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReadShutdown) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then shutdown from another thread. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + }); + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); + t.Join(); + + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, WriteShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SynchronousReceive) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_ from another thread. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + // Receive the data prior to actually starting the other thread. + char received[512]; + EXPECT_THAT( + RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Start the thread. + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, + this->addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + }); + + EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(512)); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + } + + // Receive the data as 3 separate packets. + char received[6 * psize]; + for (int i = 0; i < 3; ++i) { + EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), + SyscallSucceedsWithValue(psize)); + } + EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Direct writes from sock to bind_. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(writev(sock_.get(), iov, 2), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(readv(bind_.get(), iov, 3), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_name = bind_addr_; + msg.msg_namelen = addrlen_; + msg.msg_iov = iov; + msg.msg_iovlen = 2; + ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_iov = iov; + msg.msg_iovlen = 3; + ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, FIONREADShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + + int n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, FIONREADWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), + SyscallSucceedsWithValue(sizeof(str))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); +} + +// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the +// corresponding macro and become `0x541B`. +TEST_P(UdpSocketTest, Fionread) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, psize); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(0)); + + // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet + // socket does not cause a poll event to be triggered. + if (!IsRunningWithHostinet()) { + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + } + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoNoCheck) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = kSockOptOn; + socklen_t optlen = sizeof(v); + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOn); + ASSERT_EQ(optlen, sizeof(v)); + + v = kSockOptOff; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +#ifdef __linux__ +TEST_P(UdpSocketTest, ErrorQueue) { + char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), + SyscallFailsWithErrno(EAGAIN)); +} +#endif // __linux__ + +TEST_P(UdpSocketTest, SoTimestampOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoTimestamp) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + int v = 1; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); + + struct timeval tv = {}; + memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); + + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // There should be nothing to get via ioctl. + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, WriteShutdownNotConnected) { + EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, TimestampIoctl) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); +} + +TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct timeval tv = {}; + ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +// Test that the timestamp accessed via SIOCGSTAMP is still accessible after +// SO_TIMESTAMP is enabled and used to retrieve a timestamp. +TEST_P(UdpSocketTest, TimestampIoctlPersistence) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // Enable SO_TIMESTAMP and send a message. + int v = 1; + EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be a message for SO_TIMESTAMP. + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg = {}; + iovec iov = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + + // The ioctl should return the exact same values as before. + struct timeval tv2 = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); + ASSERT_EQ(tv.tv_sec, tv2.tv_sec); + ASSERT_EQ(tv.tv_usec, tv2.tv_usec); +} + +// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on +// outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SetAndReceiveTOS) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Set socket TOS. + int sent_level = recv_level; + int sent_type = IP_TOS; + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + } + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, + sizeof(sent_tos)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_type == IPV6_TCLASS) { + cmsg_data_len = sizeof(int); + } + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = received_cmsgbuf.size(); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the +// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SendAndReceiveTOS) { + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + int recv_opt = kSockOptOn; + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, + sizeof(recv_opt)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + int sent_level = recv_level; + int sent_type = IP_TOS; + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + cmsg_data_len = sizeof(int); + } + std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + sent_msg.msg_control = &sent_cmsgbuf[0]; + sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + + // Manually add control message. + struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); + sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); + sent_cmsg->cmsg_level = sent_level; + sent_cmsg->cmsg_type = sent_type; + *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Bind bind_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + { + // Send data of size min and verify that it's received. + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + } + + { + // Send data of size min + 1 and verify that its received. Both linux and + // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer + // is currently empty. + std::vector<char> buf(min + 1); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + } +} + +// Test that receive buffer limits are enforced. +TEST_P(UdpSocketTest, RecvBufLimits) { + // Bind s_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + // Now set the limit to min * 4. + int new_rcv_buf_sz = min * 4; + if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { + // Linux doubles the value specified so just set to min * 2. + new_rcv_buf_sz = min * 2; + } + + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + sizeof(new_rcv_buf_sz)), + SyscallSucceeds()); + int rcv_buf_sz = 0; + { + socklen_t rcv_buf_len = sizeof(rcv_buf_sz); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, + &rcv_buf_len), + SyscallSucceeds()); + } + + { + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + int sent = 4; + if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { + // Linux seems to drop the 4th packet even though technically it should + // fit in the receive buffer. + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + sent++; + } + + for (int i = 0; i < sent - 1; i++) { + // Receive the data. + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallSucceedsWithValue(received.size())); + EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); + } + + // The last receive should fail with EAGAIN as the last packet should have + // been dropped due to lack of space in the receive buffer. + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + +#ifdef __linux__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(UdpSocketTest, SetSocketDetachFilter) { + // Program generated using sudo tcpdump -i lo udp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +#endif // __linux__ + +TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT( + getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc deleted file mode 100644 index 54a0594f7..000000000 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __fuchsia__ - -#include <arpa/inet.h> -#include <fcntl.h> -#include <linux/errqueue.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/udp_socket_test_cases.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(UdpSocketTest, ErrorQueue) { - char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. - EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), - SyscallFailsWithErrno(EAGAIN)); -} - -} // namespace testing -} // namespace gvisor - -#endif // __fuchsia__ diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc deleted file mode 100644 index 60c48ed6e..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ /dev/null @@ -1,1781 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/udp_socket_test_cases.h" - -#include <arpa/inet.h> -#include <fcntl.h> -#ifndef __fuchsia__ -#include <linux/filter.h> -#endif // __fuchsia__ -#include <netinet/in.h> -#include <poll.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "absl/strings/str_format.h" -#ifndef SIOCGSTAMP -#include <linux/sockios.h> -#endif - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -// Gets a pointer to the port component of the given address. -uint16_t* Port(struct sockaddr_storage* addr) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - return &sin->sin_port; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - return &sin6->sin6_port; - } - } - - return nullptr; -} - -// Sets addr port to "port". -void SetPort(struct sockaddr_storage* addr, uint16_t port) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - sin->sin_port = port; - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - sin6->sin6_port = port; - break; - } - } -} - -void UdpSocketTest::SetUp() { - addrlen_ = GetAddrLength(); - - bind_ = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); - bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - - sock_ = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); -} - -int UdpSocketTest::GetFamily() { - if (GetParam() == AddressFamily::kIpv4) { - return AF_INET; - } - return AF_INET6; -} - -PosixError UdpSocketTest::BindLoopback() { - bind_addr_storage_ = InetLoopbackAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - return BindSocket(bind_.get(), bind_addr_); -} - -PosixError UdpSocketTest::BindAny() { - bind_addr_storage_ = InetAnyAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - return BindSocket(bind_.get(), bind_addr_); -} - -PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { - socklen_t len = sizeof(bind_addr_storage_); - - // Bind, then check that we get the right address. - RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); - - RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); - - if (addrlen_ != len) { - return PosixError( - EINVAL, - absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); - } - return PosixError(0); -} - -socklen_t UdpSocketTest::GetAddrLength() { - struct sockaddr_storage addr; - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - return sizeof(*sin); - } - - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - return sizeof(*sin6); -} - -sockaddr_storage UdpSocketTest::InetAnyAddr() { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); - - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - sin->sin_port = htons(0); - return addr; - } - - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - sin6->sin6_addr = IN6ADDR_ANY_INIT; - sin6->sin6_port = htons(0); - return addr; -} - -sockaddr_storage UdpSocketTest::InetLoopbackAddr() { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); - - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - sin->sin_port = htons(0); - return addr; - } - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - sin6->sin6_addr = in6addr_loopback; - sin6->sin6_port = htons(0); - return addr; -} - -void UdpSocketTest::Disconnect(int sockfd) { - sockaddr_storage addr_storage = InetAnyAddr(); - sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - socklen_t addrlen = sizeof(addr_storage); - - addr->sa_family = AF_UNSPEC; - ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); - - // Check that after disconnect the socket is bound to the ANY address. - EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback = IN6ADDR_ANY_INIT; - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, Creation) { - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - EXPECT_THAT(close(sock.release()), SyscallSucceeds()); - - sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); - EXPECT_THAT(close(sock.release()), SyscallSucceeds()); - - ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); -} - -TEST_P(UdpSocketTest, Getsockname) { - // Check that we're not bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), - 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, Getpeername) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that we're not connected. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Connect, then check that we get the right address. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, SendNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Do send & write, they must fail. - char buf[512]; - EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); - - EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), - SyscallFailsWithErrno(EDESTADDRREQ)); - - // Use sendto. - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ConnectBinds) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ReceiveNotBound) { - char buf[512]; - EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, Bind) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Try to bind again. - EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), - SyscallFailsWithErrno(EINVAL)); - - // Check that we're still bound to the original address. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, BindInUse) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Try to bind again. - EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(UdpSocketTest, ReceiveAfterConnect) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send from sock_ to bind_ - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - for (int i = 0; i < 2; i++) { - // Connet sock_ to bound address. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - // Send from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Disconnect sock_. - struct sockaddr unspec = {}; - unspec.sa_family = AF_UNSPEC; - ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), - SyscallSucceeds()); - } -} - -TEST_P(UdpSocketTest, Connect) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); - - // Try to bind after connect. - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallFailsWithErrno(EINVAL)); - - struct sockaddr_storage bind2_storage = InetLoopbackAddr(); - struct sockaddr* bind2_addr = - reinterpret_cast<struct sockaddr*>(&bind2_storage); - FileDescriptor bind2 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); - - // Try to connect again. - EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); - - // Check that peer name changed. - peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); -} - -TEST_P(UdpSocketTest, ConnectAnyZero) { - // TODO(138658473): Enable when we can connect to port 0 with gVisor. - SKIP_IF(IsRunningOnGvisor()); - - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, ConnectAnyWithPort) { - ASSERT_NO_ERRNO(BindAny()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { - // TODO(138658473): Enable when we can connect to port 0 with gVisor. - SKIP_IF(IsRunningOnGvisor()); - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - Disconnect(sock_.get()); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { - ASSERT_NO_ERRNO(BindAny()); - EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); - - Disconnect(sock_.get()); -} - -TEST_P(UdpSocketTest, DisconnectAfterBind) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Bind to the next port above bind_. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage unspec = {}; - unspec.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), - sizeof(unspec.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - socklen_t addrlen = sizeof(unspec); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { - ASSERT_NO_ERRNO(BindAny()); - - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - socklen_t addrlen = sizeof(addr); - - // Connect the socket. - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); - - // If the socket is bound to ANY and connected to a loopback address, - // getsockname() has to return the loopback address. - if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); - struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); - SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); - - ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - Disconnect(sock_.get()); - - // Check that we're still bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, any, addrlen), 0); - - addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, Disconnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); - SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); - - for (int i = 0; i < 2; i++) { - // Try to connect again. - EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); - - // Try to disconnect. - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), - sizeof(addr.ss_family)), - SyscallSucceeds()); - - peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&addr), *Port(&any_storage)); - } -} - -TEST_P(UdpSocketTest, ConnectBadAddress) { - struct sockaddr addr = {}; - addr.sa_family = GetFamily(); - ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage addr_storage = InetAnyAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send to a different destination than we're connected to. - char buf[512]; - EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - // Connect to loopback:bind_addr_+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from bind_ to sock_. - ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {sock_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), - SyscallSucceedsWithValue(1)); - - // Receive the packet. - char received[3]; - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_port+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Set sock to non-blocking. - int opts = 0; - ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from bind_ to sock_. - ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {sock_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // Receive the packet. - char received[3]; - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send some data to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, SendAndReceiveConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:TestPort+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveFromNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_port+2. - struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); - SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); - ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that the data isn't received because it was sent from a different - // address than we're connected. - EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveBeforeConnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Bind sock to loopback:bind_addr_port+2. - struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); - SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); - ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Connect to loopback:TestPort+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Receive the data. It works because it was sent before the connect. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Send again. This time it should not be received. - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveFrom) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:TestPort+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data and sender address. - char received[sizeof(buf)]; - struct sockaddr_storage addr2; - socklen_t addr2len = sizeof(addr2); - EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr2), &addr2len), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - EXPECT_EQ(addr2len, addrlen_); - EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); -} - -TEST_P(UdpSocketTest, Listen) { - ASSERT_THAT(listen(sock_.get(), SOMAXCONN), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -TEST_P(UdpSocketTest, Accept) { - ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// This test validates that a read shutdown with pending data allows the read -// to proceed with the data before returning EAGAIN. -TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Verify that we get EWOULDBLOCK when there is nothing to read. - char received[512]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - const char* buf = "abc"; - EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); - - int opts = 0; - ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_NE(opts & O_NONBLOCK, 0); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // We should get the data even though read has been shutdown. - EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2)); - - // Because we read less than the entire packet length, since it's a packet - // based socket any subsequent reads should return EWOULDBLOCK. - EXPECT_THAT(recv(bind_.get(), received, 1, 0), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -// This test is validating that even after a socket is shutdown if it's -// reconnected it will reset the shutdown state. -TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { - char received[512]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReadShutdown) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - - char received[512]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - ASSERT_NO_ERRNO(BindLoopback()); - - char received[512]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then shutdown from another thread. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - }); - EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); - t.Join(); - - EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, WriteShutdown) { - ASSERT_NO_ERRNO(BindLoopback()); - EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SynchronousReceive) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send some data to bind_ from another thread. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - // Receive the data prior to actually starting the other thread. - char received[512]; - EXPECT_THAT( - RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Start the thread. - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, - this->addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - }); - - EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(512)); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(psize)); - } - - // Receive the data as 3 separate packets. - char received[6 * psize]; - for (int i = 0; i < 3; ++i) { - EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), - SyscallSucceedsWithValue(psize)); - } - EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Direct writes from sock to bind_. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send 2 packets from sock to bind_, where each packet's data consists of - // 2 discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(writev(sock_.get(), iov, 2), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(readv(bind_.get(), iov, 3), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send 2 packets from sock to bind_, where each packet's data consists of - // 2 discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_name = bind_addr_; - msg.msg_namelen = addrlen_; - msg.msg_iov = iov; - msg.msg_iovlen = 2; - ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_iov = iov; - msg.msg_iovlen = 3; - ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, FIONREADShutdown) { - ASSERT_NO_ERRNO(BindLoopback()); - - int n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, FIONREADWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), - SyscallSucceedsWithValue(sizeof(str))); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); -} - -// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the -// corresponding macro and become `0x541B`. -TEST_P(UdpSocketTest, Fionread) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(psize)); - - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, psize); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(0)); - - // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet - // socket does not cause a poll event to be triggered. - if (!IsRunningWithHostinet()) { - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - } - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoNoCheck) { - // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = kSockOptOn; - socklen_t optlen = sizeof(v); - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), - SyscallSucceeds()); - v = -1; - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOn); - ASSERT_EQ(optlen, sizeof(v)); - - v = kSockOptOff; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), - SyscallSucceeds()); - v = -1; - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestampOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestamp) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - int v = 1; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(0)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); - - struct timeval tv = {}; - memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); - - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), - SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, WriteShutdownNotConnected) { - EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, TimestampIoctl) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); -} - -TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct timeval tv = {}; - ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), - SyscallFailsWithErrno(ENOENT)); -} - -// Test that the timestamp accessed via SIOCGSTAMP is still accessible after -// SO_TIMESTAMP is enabled and used to retrieve a timestamp. -TEST_P(UdpSocketTest, TimestampIoctlPersistence) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // Enable SO_TIMESTAMP and send a message. - int v = 1; - EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be a message for SO_TIMESTAMP. - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg = {}; - iovec iov = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(0)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - - // The ioctl should return the exact same values as before. - struct timeval tv2 = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); - ASSERT_EQ(tv.tv_sec, tv2.tv_sec); - ASSERT_EQ(tv.tv_usec, tv2.tv_usec); -} - -// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on -// outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SetAndReceiveTOS) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Set socket TOS. - int sent_level = recv_level; - int sent_type = IP_TOS; - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - } - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, - sizeof(sent_tos)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_type == IPV6_TCLASS) { - cmsg_data_len = sizeof(int); - } - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = received_cmsgbuf.size(); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} - -// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the -// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - int recv_opt = kSockOptOn; - ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, - sizeof(recv_opt)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - int sent_level = recv_level; - int sent_type = IP_TOS; - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - cmsg_data_len = sizeof(int); - } - std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - sent_msg.msg_control = &sent_cmsgbuf[0]; - sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - - // Manually add control message. - struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); - sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); - sent_cmsg->cmsg_level = sent_level; - sent_cmsg->cmsg_type = sent_type; - *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; - - ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} - -TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { - // Discover minimum buffer size by setting it to zero. - constexpr int kRcvBufSz = 0; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, - sizeof(kRcvBufSz)), - SyscallSucceeds()); - - int min = 0; - socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), - SyscallSucceeds()); - - // Bind bind_ to loopback. - ASSERT_NO_ERRNO(BindLoopback()); - - { - // Send data of size min and verify that it's received. - std::vector<char> buf(min); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - } - - { - // Send data of size min + 1 and verify that its received. Both linux and - // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer - // is currently empty. - std::vector<char> buf(min + 1); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - } -} - -// Test that receive buffer limits are enforced. -TEST_P(UdpSocketTest, RecvBufLimits) { - // Bind s_ to loopback. - ASSERT_NO_ERRNO(BindLoopback()); - - int min = 0; - { - // Discover minimum buffer size by trying to set it to zero. - constexpr int kRcvBufSz = 0; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, - sizeof(kRcvBufSz)), - SyscallSucceeds()); - - socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), - SyscallSucceeds()); - } - - // Now set the limit to min * 4. - int new_rcv_buf_sz = min * 4; - if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { - // Linux doubles the value specified so just set to min * 2. - new_rcv_buf_sz = min * 2; - } - - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, - sizeof(new_rcv_buf_sz)), - SyscallSucceeds()); - int rcv_buf_sz = 0; - { - socklen_t rcv_buf_len = sizeof(rcv_buf_sz); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, - &rcv_buf_len), - SyscallSucceeds()); - } - - { - std::vector<char> buf(min); - RandomizeBuffer(buf.data(), buf.size()); - - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - int sent = 4; - if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { - // Linux seems to drop the 4th packet even though technically it should - // fit in the receive buffer. - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - sent++; - } - - for (int i = 0; i < sent - 1; i++) { - // Receive the data. - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); - } - - // The last receive should fail with EAGAIN as the last packet should have - // been dropped due to lack of space in the receive buffer. - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - } -} - -#ifndef __fuchsia__ - -// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. -// gVisor currently silently ignores attaching a filter. -TEST_P(UdpSocketTest, SetSocketDetachFilter) { - // Program generated using sudo tcpdump -i lo udp and port 1234 -dd - struct sock_filter code[] = { - {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, - {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, - {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, - {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, - {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, - {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, - {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, - {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, - {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, - {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, - }; - struct sock_fprog bpf = { - .len = ABSL_ARRAYSIZE(code), - .filter = code, - }; - ASSERT_THAT( - setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), - SyscallSucceeds()); - - constexpr int val = 0; - ASSERT_THAT( - setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), - SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { - // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. - SKIP_IF(IsRunningOnGvisor()); - constexpr int val = 0; - ASSERT_THAT( - setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), - SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, GetSocketDetachFilter) { - int val = 0; - socklen_t val_len = sizeof(val); - ASSERT_THAT( - getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), - SyscallFailsWithErrno(ENOPROTOOPT)); -} - -#endif // __fuchsia__ - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h deleted file mode 100644 index f7e25c805..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ -#define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ - -#include <sys/socket.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// The initial port to be be used on gvisor. -constexpr int TestPort = 40000; - -// Fixture for tests parameterized by the address family to use (AF_INET and -// AF_INET6) when creating sockets. -class UdpSocketTest - : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { - protected: - // Creates two sockets that will be used by test cases. - void SetUp() override; - - // Binds the socket bind_ to the loopback and updates bind_addr_. - PosixError BindLoopback(); - - // Binds the socket bind_ to Any and updates bind_addr_. - PosixError BindAny(); - - // Binds given socket to address addr and updates. - PosixError BindSocket(int socket, struct sockaddr* addr); - - // Return initialized Any address to port 0. - struct sockaddr_storage InetAnyAddr(); - - // Return initialized Loopback address to port 0. - struct sockaddr_storage InetLoopbackAddr(); - - // Disconnects socket sockfd. - void Disconnect(int sockfd); - - // Get family for the test. - int GetFamily(); - - // Socket used by Bind methods - FileDescriptor bind_; - - // Second socket used for tests. - FileDescriptor sock_; - - // Address for bind_ socket. - struct sockaddr* bind_addr_; - - // Initialized to the length based on GetFamily(). - socklen_t addrlen_; - - // Storage for bind_addr_. - struct sockaddr_storage bind_addr_storage_; - - private: - // Helper to initialize addrlen_ for the test case. - socklen_t GetAddrLength(); -}; -} // namespace testing -} // namespace gvisor - -#endif // THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc index 2040375c9..061e2e0f1 100644 --- a/test/syscalls/linux/unlink.cc +++ b/test/syscalls/linux/unlink.cc @@ -208,6 +208,20 @@ TEST(RmdirTest, CanRemoveWithTrailingSlashes) { ASSERT_THAT(rmdir(slashslash.c_str()), SyscallSucceeds()); } +TEST(UnlinkTest, UnlinkAtEmptyPath) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + EXPECT_THAT(unlinkat(fd.get(), "", 0), SyscallFailsWithErrno(ENOENT)); + + auto dirInDir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); + auto dirFD = ASSERT_NO_ERRNO_AND_VALUE( + Open(dirInDir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(unlinkat(dirFD.get(), "", AT_REMOVEDIR), + SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc index ce1899f45..2a8699a7b 100644 --- a/test/syscalls/linux/vdso_clock_gettime.cc +++ b/test/syscalls/linux/vdso_clock_gettime.cc @@ -38,8 +38,6 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { switch (info.param) { case CLOCK_MONOTONIC: return "CLOCK_MONOTONIC"; - case CLOCK_REALTIME: - return "CLOCK_REALTIME"; case CLOCK_BOOTTIME: return "CLOCK_BOOTTIME"; default: @@ -47,59 +45,36 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { } } -class CorrectVDSOClockTest : public ::testing::TestWithParam<clockid_t> {}; +class MonotonicVDSOClockTest : public ::testing::TestWithParam<clockid_t> {}; -TEST_P(CorrectVDSOClockTest, IsCorrect) { +TEST_P(MonotonicVDSOClockTest, IsCorrect) { + // The VDSO implementation of clock_gettime() uses the TSC. On KVM, sentry and + // application TSCs can be very desynchronized; see + // sentry/platform/kvm/kvm.vCPU.setSystemTime(). + SKIP_IF(GvisorPlatform() == Platform::kKVM); + + // Check that when we alternate readings from the clock_gettime syscall and + // the VDSO's implementation, we observe the combined sequence as being + // monotonic. struct timespec tvdso, tsys; absl::Time vdso_time, sys_time; - uint64_t total_calls = 0; - - // It is expected that 82.5% of clock_gettime calls will be less than 100us - // skewed from the system time. - // Unfortunately this is not only influenced by the VDSO clock skew, but also - // by arbitrary scheduling delays and the like. The test is therefore - // regularly disabled. - std::map<absl::Duration, std::tuple<double, uint64_t, uint64_t>> confidence = - { - {absl::Microseconds(100), std::make_tuple(0.825, 0, 0)}, - {absl::Microseconds(250), std::make_tuple(0.94, 0, 0)}, - {absl::Milliseconds(1), std::make_tuple(0.999, 0, 0)}, - }; - - absl::Time start = absl::Now(); - while (absl::Now() < start + absl::Seconds(30)) { - EXPECT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds()); - EXPECT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), - SyscallSucceeds()); - + ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), + SyscallSucceeds()); + sys_time = absl::TimeFromTimespec(tsys); + auto end = absl::Now() + absl::Seconds(10); + while (absl::Now() < end) { + ASSERT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds()); vdso_time = absl::TimeFromTimespec(tvdso); - - for (auto const& conf : confidence) { - std::get<1>(confidence[conf.first]) += - (sys_time - vdso_time) < conf.first; - } - + EXPECT_LE(sys_time, vdso_time); + ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), + SyscallSucceeds()); sys_time = absl::TimeFromTimespec(tsys); - - for (auto const& conf : confidence) { - std::get<2>(confidence[conf.first]) += - (vdso_time - sys_time) < conf.first; - } - - ++total_calls; - } - - for (auto const& conf : confidence) { - EXPECT_GE(std::get<1>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); - EXPECT_GE(std::get<2>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); + EXPECT_LE(vdso_time, sys_time); } } -INSTANTIATE_TEST_SUITE_P(ClockGettime, CorrectVDSOClockTest, - ::testing::Values(CLOCK_MONOTONIC, CLOCK_REALTIME, - CLOCK_BOOTTIME), +INSTANTIATE_TEST_SUITE_P(ClockGettime, MonotonicVDSOClockTest, + ::testing::Values(CLOCK_MONOTONIC, CLOCK_BOOTTIME), PrintClockId); } // namespace diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 39b5b2f56..77bcfbb8a 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -133,6 +133,91 @@ TEST_F(WriteTest, WriteExceedsRLimit) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(WriteTest, WriteIncrementOffset) { + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 0), SyscallSucceedsWithValue(0)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); + + const int bytes_total = 1024; + + EXPECT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); +} + +TEST_F(WriteTest, WriteIncrementOffsetSeek) { + const std::string data = "hello world\n"; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + const int seek_offset = data.size() / 2; + ASSERT_THAT(lseek(fd, seek_offset, SEEK_SET), + SyscallSucceedsWithValue(seek_offset)); + + const int write_bytes = 512; + EXPECT_THAT(WriteBytes(fd, write_bytes), + SyscallSucceedsWithValue(write_bytes)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(seek_offset + write_bytes)); +} + +TEST_F(WriteTest, WriteIncrementOffsetAppend) { + const std::string data = "hello world\n"; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpfile.path().c_str(), O_WRONLY | O_APPEND)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(data.size() + 1024)); +} + +TEST_F(WriteTest, WriteIncrementOffsetEOF) { + const std::string data = "hello world\n"; + const TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + EXPECT_THAT(lseek(fd, 0, SEEK_END), SyscallSucceedsWithValue(data.size())); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_END), + SyscallSucceedsWithValue(data.size() + 1024)); +} + +TEST_F(WriteTest, PwriteNoChangeOffset) { + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + const std::string data = "hello world\n"; + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); + + const int bytes_total = 1024; + ASSERT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); + ASSERT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), bytes_total), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index cbcf08451..bd3f829c4 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -28,6 +28,7 @@ #include "test/syscalls/linux/file_base.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -37,6 +38,8 @@ namespace testing { namespace { +using ::gvisor::testing::IsTmpfs; + class XattrTest : public FileTest {}; TEST_F(XattrTest, XattrNonexistentFile) { @@ -229,7 +232,7 @@ TEST_F(XattrTest, XattrOnInvalidFileTypes) { EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); } -TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { +TEST_F(XattrTest, SetXattrSizeSmallerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -244,7 +247,7 @@ TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrZeroSize) { +TEST_F(XattrTest, SetXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -256,7 +259,7 @@ TEST_F(XattrTest, SetxattrZeroSize) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, SetxattrSizeTooLarge) { +TEST_F(XattrTest, SetXattrSizeTooLarge) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; @@ -271,7 +274,7 @@ TEST_F(XattrTest, SetxattrSizeTooLarge) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { +TEST_F(XattrTest, SetXattrNullValueAndNonzeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0), @@ -280,7 +283,7 @@ TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { +TEST_F(XattrTest, SetXattrNullValueAndZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -288,7 +291,7 @@ TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { +TEST_F(XattrTest, SetXattrValueTooLargeButOKSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val(XATTR_SIZE_MAX + 1); @@ -304,7 +307,7 @@ TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrReplaceWithSmaller) { +TEST_F(XattrTest, SetXattrReplaceWithSmaller) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -319,7 +322,7 @@ TEST_F(XattrTest, SetxattrReplaceWithSmaller) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrReplaceWithLarger) { +TEST_F(XattrTest, SetXattrReplaceWithLarger) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -333,7 +336,7 @@ TEST_F(XattrTest, SetxattrReplaceWithLarger) { EXPECT_EQ(buf, val); } -TEST_F(XattrTest, SetxattrCreateFlag) { +TEST_F(XattrTest, SetXattrCreateFlag) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), @@ -344,7 +347,7 @@ TEST_F(XattrTest, SetxattrCreateFlag) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrReplaceFlag) { +TEST_F(XattrTest, SetXattrReplaceFlag) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), @@ -356,14 +359,14 @@ TEST_F(XattrTest, SetxattrReplaceFlag) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrInvalidFlags) { +TEST_F(XattrTest, SetXattrInvalidFlags) { const char* path = test_file_name_.c_str(); int invalid_flags = 0xff; EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags), SyscallFailsWithErrno(EINVAL)); } -TEST_F(XattrTest, Getxattr) { +TEST_F(XattrTest, GetXattr) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; int val = 1234; @@ -375,7 +378,7 @@ TEST_F(XattrTest, Getxattr) { EXPECT_EQ(buf, val); } -TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { +TEST_F(XattrTest, GetXattrSizeSmallerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -387,7 +390,7 @@ TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, GetxattrSizeLargerThanValue) { +TEST_F(XattrTest, GetXattrSizeLargerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -402,7 +405,7 @@ TEST_F(XattrTest, GetxattrSizeLargerThanValue) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, GetxattrZeroSize) { +TEST_F(XattrTest, GetXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -415,7 +418,7 @@ TEST_F(XattrTest, GetxattrZeroSize) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, GetxattrSizeTooLarge) { +TEST_F(XattrTest, GetXattrSizeTooLarge) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -431,7 +434,7 @@ TEST_F(XattrTest, GetxattrSizeTooLarge) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, GetxattrNullValue) { +TEST_F(XattrTest, GetXattrNullValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -442,7 +445,7 @@ TEST_F(XattrTest, GetxattrNullValue) { SyscallFailsWithErrno(EFAULT)); } -TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { +TEST_F(XattrTest, GetXattrNullValueAndZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -458,13 +461,13 @@ TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size)); } -TEST_F(XattrTest, GetxattrNonexistentName) { +TEST_F(XattrTest, GetXattrNonexistentName) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, Listxattr) { +TEST_F(XattrTest, ListXattr) { const char* path = test_file_name_.c_str(); const std::string name = "user.test"; const std::string name2 = "user.test2"; @@ -490,7 +493,7 @@ TEST_F(XattrTest, Listxattr) { EXPECT_EQ(got, expected); } -TEST_F(XattrTest, ListxattrNoXattrs) { +TEST_F(XattrTest, ListXattrNoXattrs) { const char* path = test_file_name_.c_str(); std::vector<char> list, expected; @@ -498,13 +501,13 @@ TEST_F(XattrTest, ListxattrNoXattrs) { SyscallSucceedsWithValue(0)); EXPECT_EQ(list, expected); - // Listxattr should succeed if there are no attributes, even if the buffer + // ListXattr should succeed if there are no attributes, even if the buffer // passed in is a nullptr. EXPECT_THAT(listxattr(path, nullptr, sizeof(list)), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, ListxattrNullBuffer) { +TEST_F(XattrTest, ListXattrNullBuffer) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -513,7 +516,7 @@ TEST_F(XattrTest, ListxattrNullBuffer) { SyscallFailsWithErrno(EFAULT)); } -TEST_F(XattrTest, ListxattrSizeTooSmall) { +TEST_F(XattrTest, ListXattrSizeTooSmall) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -523,7 +526,7 @@ TEST_F(XattrTest, ListxattrSizeTooSmall) { SyscallFailsWithErrno(ERANGE)); } -TEST_F(XattrTest, ListxattrZeroSize) { +TEST_F(XattrTest, ListXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -604,6 +607,83 @@ TEST_F(XattrTest, XattrWithFD) { EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); } +TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) { + // Trusted namespace not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.test"; + + // Writing to the trusted.* xattr namespace requires CAP_SYS_ADMIN in the root + // user namespace. There's no easy way to check that, other than trying the + // operation and seeing what happens. We'll call removexattr because it's + // simplest. + if (removexattr(path, name) < 0) { + SKIP_IF(errno == EPERM); + FAIL() << "unexpected errno from removexattr: " << errno; + } + + // Set. + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + // Get. + char got = '\0'; + EXPECT_THAT(getxattr(path, name, &got, size), SyscallSucceedsWithValue(size)); + EXPECT_EQ(val, got); + + // List. + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + // Remove. + EXPECT_THAT(removexattr(path, name), SyscallSucceeds()); + + // Get should now return ENODATA. + EXPECT_THAT(getxattr(path, name, &got, size), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, TrustedNamespaceWithoutCapSysAdmin) { + // Trusted namespace not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); + + // Drop CAP_SYS_ADMIN if we have it. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { + EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); + } + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.test"; + + // Set fails. + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + + // Get fails. + char got = '\0'; + EXPECT_THAT(getxattr(path, name, &got, size), SyscallFailsWithErrno(ENODATA)); + + // List still works, but returns no items. + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), SyscallSucceedsWithValue(0)); + + // Remove fails. + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); +} + } // namespace } // namespace testing diff --git a/test/util/BUILD b/test/util/BUILD index 2a17c33ee..fc5fb3a8d 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -46,6 +46,13 @@ cc_library( ) cc_library( + name = "fuse_util", + testonly = 1, + srcs = ["fuse_util.cc"], + hdrs = ["fuse_util.h"], +) + +cc_library( name = "proc_util", testonly = 1, srcs = ["proc_util.cc"], diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 5418948fe..b16055dd8 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -15,7 +15,11 @@ #include "test/util/fs_util.h" #include <dirent.h> +#ifdef __linux__ +#include <linux/magic.h> +#endif // __linux__ #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -629,5 +633,35 @@ PosixErrorOr<std::string> ProcessExePath(int pid) { return ReadLink(absl::StrCat("/proc/", pid, "/exe")); } +#ifdef __linux__ +PosixErrorOr<bool> IsTmpfs(const std::string& path) { + struct statfs stat; + if (statfs(path.c_str(), &stat)) { + if (errno == ENOENT) { + // Nothing at path, don't raise this as an error. Instead, just report no + // tmpfs at path. + return false; + } + return PosixError(errno, + absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); + } + return stat.f_type == TMPFS_MAGIC; +} +#endif // __linux__ + +PosixErrorOr<bool> IsOverlayfs(const std::string& path) { + struct statfs stat; + if (statfs(path.c_str(), &stat)) { + if (errno == ENOENT) { + // Nothing at path, don't raise this as an error. Instead, just report no + // overlayfs at path. + return false; + } + return PosixError(errno, + absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); + } + return stat.f_type == OVERLAYFS_SUPER_MAGIC; +} + } // namespace testing } // namespace gvisor diff --git a/test/util/fs_util.h b/test/util/fs_util.h index 8cdac23a1..c99cf5eb7 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -17,6 +17,7 @@ #include <dirent.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -37,6 +38,10 @@ constexpr int kOLargeFile = 00400000; #error "Unknown architecture" #endif +// From linux/magic.h. For some reason, not defined in the headers for some +// build environments. +#define OVERLAYFS_SUPER_MAGIC 0x794c7630 + // Returns a status or the current working directory. PosixErrorOr<std::string> GetCWD(); @@ -178,6 +183,14 @@ std::string CleanPath(absl::string_view path); // Returns the full path to the executable of the given pid or a PosixError. PosixErrorOr<std::string> ProcessExePath(int pid); +#ifdef __linux__ +// IsTmpfs returns true if the file at path is backed by tmpfs. +PosixErrorOr<bool> IsTmpfs(const std::string& path); +#endif // __linux__ + +// IsOverlayfs returns true if the file at path is backed by overlayfs. +PosixErrorOr<bool> IsOverlayfs(const std::string& path); + namespace internal { // Not part of the public API. std::string JoinPathImpl(std::initializer_list<absl::string_view> paths); diff --git a/test/util/fuse_util.cc b/test/util/fuse_util.cc new file mode 100644 index 000000000..027f8386c --- /dev/null +++ b/test/util/fuse_util.cc @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/fuse_util.h" + +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> + +namespace gvisor { +namespace testing { + +// Create a default FuseAttr struct with specified mode, inode, and size. +fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size) { + const int time_sec = 1595436289; + const int time_nsec = 134150844; + return (struct fuse_attr){ + .ino = inode, + .size = size, + .blocks = 4, + .atime = time_sec, + .mtime = time_sec, + .ctime = time_sec, + .atimensec = time_nsec, + .mtimensec = time_nsec, + .ctimensec = time_nsec, + .mode = mode, + .nlink = 2, + .uid = 1234, + .gid = 4321, + .rdev = 12, + .blksize = 4096, + }; +} + +// Create response body with specified mode, nodeID, and size. +fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id, uint64_t size) { + struct fuse_entry_out default_entry_out = { + .nodeid = node_id, + .generation = 0, + .entry_valid = 0, + .attr_valid = 0, + .entry_valid_nsec = 0, + .attr_valid_nsec = 0, + .attr = DefaultFuseAttr(mode, node_id, size), + }; + return default_entry_out; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/fuse_util.h b/test/util/fuse_util.h new file mode 100644 index 000000000..544fe1b38 --- /dev/null +++ b/test/util/fuse_util.h @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_FUSE_UTIL_H_ +#define GVISOR_TEST_UTIL_FUSE_UTIL_H_ + +#include <linux/fuse.h> +#include <sys/uio.h> + +#include <string> +#include <vector> + +namespace gvisor { +namespace testing { + +// The fundamental generation function with a single argument. If passed by +// std::string or std::vector<char>, it will call specialized versions as +// implemented below. +template <typename T> +std::vector<struct iovec> FuseGenerateIovecs(T &first) { + return {(struct iovec){.iov_base = &first, .iov_len = sizeof(first)}}; +} + +// If an argument is of type std::string, it must be used in read-only scenario. +// Because we are setting up iovec, which contains the original address of a +// data structure, we have to drop const qualification. Usually used with +// variable-length payload data. +template <typename T = std::string> +std::vector<struct iovec> FuseGenerateIovecs(std::string &first) { + // Pad one byte for null-terminate c-string. + return {(struct iovec){.iov_base = const_cast<char *>(first.c_str()), + .iov_len = first.size() + 1}}; +} + +// If an argument is of type std::vector<char>, it must be used in write-only +// scenario and the size of the variable must be greater than or equal to the +// size of the expected data. Usually used with variable-length payload data. +template <typename T = std::vector<char>> +std::vector<struct iovec> FuseGenerateIovecs(std::vector<char> &first) { + return {(struct iovec){.iov_base = first.data(), .iov_len = first.size()}}; +} + +// A helper function to set up an array of iovec struct for testing purpose. +// Use variadic class template to generalize different numbers and different +// types of FUSE structs. +template <typename T, typename... Types> +std::vector<struct iovec> FuseGenerateIovecs(T &first, Types &...args) { + auto first_iovec = FuseGenerateIovecs(first); + auto iovecs = FuseGenerateIovecs(args...); + first_iovec.insert(std::end(first_iovec), std::begin(iovecs), + std::end(iovecs)); + return first_iovec; +} + +// Create a fuse_attr filled with the specified mode and inode. +fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size = 512); + +// Return a fuse_entry_out FUSE server response body. +fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id, + uint64_t size = 512); + +} // namespace testing +} // namespace gvisor +#endif // GVISOR_TEST_UTIL_FUSE_UTIL_H_ diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc index c01f916aa..2cf0bea74 100644 --- a/test/util/pty_util.cc +++ b/test/util/pty_util.cc @@ -23,15 +23,15 @@ namespace gvisor { namespace testing { -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - PosixErrorOr<int> n = SlaveID(master); +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master) { + PosixErrorOr<int> n = ReplicaID(master); if (!n.ok()) { return PosixErrorOr<FileDescriptor>(n.error()); } return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); } -PosixErrorOr<int> SlaveID(const FileDescriptor& master) { +PosixErrorOr<int> ReplicaID(const FileDescriptor& master) { // Get pty index. int n; int ret = ioctl(master.get(), TIOCGPTN, &n); diff --git a/test/util/pty_util.h b/test/util/pty_util.h index 0722da379..ed7658868 100644 --- a/test/util/pty_util.h +++ b/test/util/pty_util.h @@ -21,11 +21,11 @@ namespace gvisor { namespace testing { -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); +// Opens the replica end of the passed master as R/W and nonblocking. +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master); -// Get the number of the slave end of the master. -PosixErrorOr<int> SlaveID(const FileDescriptor& master); +// Get the number of the replica end of the master. +PosixErrorOr<int> ReplicaID(const FileDescriptor& master); } // namespace testing } // namespace gvisor diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc index 694d21692..7210094eb 100644 --- a/test/util/test_util_runfiles.cc +++ b/test/util/test_util_runfiles.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef __fuchsia__ - #include <iostream> #include <string> @@ -46,5 +44,3 @@ std::string RunfilePath(std::string path) { } // namespace testing } // namespace gvisor - -#endif // __fuchsia__ diff --git a/tools/bazel.mk b/tools/bazel.mk index 3e27af7d1..5cc1cdea2 100644 --- a/tools/bazel.mk +++ b/tools/bazel.mk @@ -31,6 +31,7 @@ DOCKER_PRIVILEGED ?= --privileged BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/) GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/) DOCKER_SOCKET := /var/run/docker.sock +DOCKER_CONFIG := /etc/docker/daemon.json # Bazel flags. BAZEL := bazel $(STARTUP_OPTIONS) @@ -56,6 +57,9 @@ endif # Add docker passthrough options. ifneq ($(DOCKER_PRIVILEGED),) FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)" +# TODO(gvisor.dev/issue/1624): Remove docker config volume. This is required +# temporarily for checking VFS1 vs VFS2 by some tests. +FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_CONFIG):$(DOCKER_CONFIG)" FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED) FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED) DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET)) diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl index db7f379b8..dad5fc3b2 100644 --- a/tools/bazeldefs/defs.bzl +++ b/tools/bazeldefs/defs.bzl @@ -87,13 +87,14 @@ def cc_binary(name, static = False, **kwargs): **kwargs ) -def go_binary(name, static = False, pure = False, **kwargs): +def go_binary(name, static = False, pure = False, x_defs = None, **kwargs): """Build a go binary. Args: name: name of the target. static: build a static binary. pure: build without cgo. + x_defs: additional definitions. **kwargs: rest of the arguments are passed to _go_binary. """ if static: @@ -102,6 +103,7 @@ def go_binary(name, static = False, pure = False, **kwargs): kwargs["pure"] = "on" _go_binary( name = name, + x_defs = x_defs, **kwargs ) @@ -145,18 +147,28 @@ def go_rule(rule, implementation, **kwargs): Returns: The result of invoking the rule. """ - attrs = kwargs.pop("attrs", []) + attrs = kwargs.pop("attrs", dict()) attrs["_go_context_data"] = attr.label(default = "@io_bazel_rules_go//:go_context_data") attrs["_stdlib"] = attr.label(default = "@io_bazel_rules_go//:stdlib") toolchains = kwargs.get("toolchains", []) + ["@io_bazel_rules_go//go:toolchain"] return rule(implementation, attrs = attrs, toolchains = toolchains, **kwargs) -def go_context(ctx): +def go_test_library(target): + if hasattr(target.attr, "embed") and len(target.attr.embed) > 0: + return target.attr.embed[0] + return None + +def go_context(ctx, std = False): + # We don't change anything for the standard library analysis. All Go files + # are available in all instances. Note that this includes the standard + # library sources, which are analyzed by nogo. go_ctx = _go_context(ctx) return struct( go = go_ctx.go, env = go_ctx.env, - runfiles = depset([go_ctx.go] + go_ctx.sdk.tools + go_ctx.stdlib.libs), + nogo_args = [], + stdlib_srcs = go_ctx.sdk.srcs, + runfiles = depset([go_ctx.go] + go_ctx.sdk.srcs + go_ctx.sdk.tools + go_ctx.stdlib.libs), goos = go_ctx.sdk.goos, goarch = go_ctx.sdk.goarch, tags = go_ctx.tags, diff --git a/tools/checkescape/BUILD b/tools/checkescape/BUILD index b8c3ddf44..8956be621 100644 --- a/tools/checkescape/BUILD +++ b/tools/checkescape/BUILD @@ -8,7 +8,6 @@ go_library( nogo = False, visibility = ["//tools/nogo:__subpackages__"], deps = [ - "//tools/nogo/data", "@org_golang_x_tools//go/analysis:go_tool_library", "@org_golang_x_tools//go/analysis/passes/buildssa:go_tool_library", "@org_golang_x_tools//go/ssa:go_tool_library", diff --git a/tools/checkescape/checkescape.go b/tools/checkescape/checkescape.go index f8def4823..d98f5c3a1 100644 --- a/tools/checkescape/checkescape.go +++ b/tools/checkescape/checkescape.go @@ -61,20 +61,20 @@ package checkescape import ( "bufio" "bytes" + "flag" "fmt" "go/ast" "go/token" "go/types" "io" "os" + "os/exec" "path/filepath" - "strconv" "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/passes/buildssa" "golang.org/x/tools/go/ssa" - "gvisor.dev/gvisor/tools/nogo/data" ) const ( @@ -91,81 +91,20 @@ const ( exempt = "// escapes" ) -// escapingBuiltins are builtins known to escape. -// -// These are lowered at an earlier stage of compilation to explicit function -// calls, but are not available for recursive analysis. -var escapingBuiltins = []string{ - "append", - "makemap", - "newobject", - "mallocgc", -} - -// Analyzer defines the entrypoint. -var Analyzer = &analysis.Analyzer{ - Name: "checkescape", - Doc: "surfaces recursive escape analysis results", - Run: run, - Requires: []*analysis.Analyzer{buildssa.Analyzer}, - FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)}, -} - -// packageEscapeFacts is the set of all functions in a package, and whether or -// not they recursively pass escape analysis. -// -// All the type names for receivers are encoded in the full key. The key -// represents the fully qualified package and type name used at link time. -type packageEscapeFacts struct { - Funcs map[string][]Escape -} - -// AFact implements analysis.Fact.AFact. -func (*packageEscapeFacts) AFact() {} - -// CallSite is a single call site. -// -// These can be chained. -type CallSite struct { - LocalPos token.Pos - Resolved LinePosition -} - -// Escape is a single escape instance. -type Escape struct { - Reason EscapeReason - Detail string - Chain []CallSite -} - -// LinePosition is a low-resolution token.Position. -// -// This is used to match against possible exemptions placed in the source. -type LinePosition struct { - Filename string - Line int -} +var ( + // Binary is the binary under analysis. + // + // See Reader, below. + binary = flag.String("binary", "", "binary under analysis") -// String implements fmt.Stringer.String. -func (e *LinePosition) String() string { - return fmt.Sprintf("%s:%d", e.Filename, e.Line) -} + // Reader is the input stream. + // + // This may be set instead of Binary. + Reader io.Reader -// String implements fmt.Stringer.String. -// -// Note that this string will contain new lines. -func (e *Escape) String() string { - var b bytes.Buffer - fmt.Fprintf(&b, "%s", e.Reason.String()) - for i, cs := range e.Chain { - if i == len(e.Chain)-1 { - fmt.Fprintf(&b, "\n @ %s → %s", cs.Resolved.String(), e.Detail) - } else { - fmt.Fprintf(&b, "\n + %s", cs.Resolved.String()) - } - } - return b.String() -} + // Tool is the tool used to dump a binary. + tool = flag.String("dump_tool", "", "tool used to dump a binary") +) // EscapeReason is an escape reason. // @@ -173,12 +112,12 @@ func (e *Escape) String() string { type EscapeReason int const ( - interfaceInvoke EscapeReason = iota - unknownPackage - allocation + allocation EscapeReason = iota builtin + interfaceInvoke dynamicCall stackSplit + unknownPackage reasonCount // Count for below. ) @@ -189,17 +128,17 @@ const ( func (e EscapeReason) String() string { switch e { case interfaceInvoke: - return "interface: function invocation via interface" + return "interface: call to potentially allocating function" case unknownPackage: return "unknown: no package information available" case allocation: - return "heap: call to runtime heap allocation" + return "heap: explicit allocation" case builtin: - return "builtin: call to runtime builtin" + return "builtin: call to potentially allocating builtin" case dynamicCall: - return "dynamic: call via dynamic function" + return "dynamic: call to potentially allocating function" case stackSplit: - return "stack: stack split on function entry" + return "stack: possible split on function entry" default: panic(fmt.Sprintf("unknown reason: %d", e)) } @@ -228,47 +167,241 @@ var escapeTypes = func() map[string]EscapeReason { return result }() -// EscapeCount counts escapes. +// escapingBuiltins are builtins known to escape. // -// It is used to avoid accumulating too many escapes for the same reason, for -// the same function. We limit each class to 3 instances (arbitrarily). -type EscapeCount struct { - byReason [reasonCount]uint32 +// These are lowered at an earlier stage of compilation to explicit function +// calls, but are not available for recursive analysis. +var escapingBuiltins = []string{ + "append", + "makemap", + "newobject", + "mallocgc", } -// maxRecordsPerReason is the number of explicit records. +// packageEscapeFacts is the set of all functions in a package, and whether or +// not they recursively pass escape analysis. // -// See EscapeCount (and usage), and Record implementation. -const maxRecordsPerReason = 5 - -// Record records the reason or returns false if it should not be added. -func (ec *EscapeCount) Record(reason EscapeReason) bool { - ec.byReason[reason]++ - if ec.byReason[reason] > maxRecordsPerReason { - return false +// All the type names for receivers are encoded in the full key. The key +// represents the fully qualified package and type name used at link time. +// +// Note that each Escapes object is a summary. Local findings may be reported +// using more detailed information. +type packageEscapeFacts struct { + Funcs map[string]Escapes +} + +// AFact implements analysis.Fact.AFact. +func (*packageEscapeFacts) AFact() {} + +// Analyzer includes specific results. +var Analyzer = &analysis.Analyzer{ + Name: "checkescape", + Doc: "escape analysis checks based on +checkescape annotations", + Run: runSelectEscapes, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, + FactTypes: []analysis.Fact{(*packageEscapeFacts)(nil)}, +} + +// EscapeAnalyzer includes all local escape results. +var EscapeAnalyzer = &analysis.Analyzer{ + Name: "checkescape", + Doc: "complete local escape analysis results (requires Analyzer facts)", + Run: runAllEscapes, + Requires: []*analysis.Analyzer{buildssa.Analyzer}, +} + +// LinePosition is a low-resolution token.Position. +// +// This is used to match against possible exemptions placed in the source. +type LinePosition struct { + Filename string + Line int +} + +// String implements fmt.Stringer.String. +func (e LinePosition) String() string { + return fmt.Sprintf("%s:%d", e.Filename, e.Line) +} + +// Simplified returns the simplified name. +func (e LinePosition) Simplified() string { + return fmt.Sprintf("%s:%d", filepath.Base(e.Filename), e.Line) +} + +// CallSite is a single call site. +// +// These can be chained. +type CallSite struct { + LocalPos token.Pos + Resolved LinePosition +} + +// IsValid indicates whether the CallSite is valid or not. +func (cs *CallSite) IsValid() bool { + return cs.LocalPos.IsValid() +} + +// Escapes is a collection of escapes. +// +// We record at most one escape for each reason, but record the number of +// escapes that were omitted. +// +// This object should be used to summarize all escapes for a single line (local +// analysis) or a single function (package facts). +// +// All fields are exported for gob. +type Escapes struct { + CallSites [reasonCount][]CallSite + Details [reasonCount]string + Omitted [reasonCount]int +} + +// add is called by Add and Merge. +func (es *Escapes) add(r EscapeReason, detail string, omitted int, callSites ...CallSite) { + if es.CallSites[r] != nil { + // We will either be replacing the current escape or dropping + // the added one. Either way, we increment omitted by the + // appropriate amount. + es.Omitted[r]++ + // If the callSites in the other is only a single element, then + // we will universally favor this. This provides the cleanest + // set of escapes to summarize, and more importantly: if there + if len(es.CallSites) == 1 || len(callSites) != 1 { + return + } + } + es.Details[r] = detail + es.CallSites[r] = callSites + es.Omitted[r] += omitted +} + +// Add adds a single escape. +func (es *Escapes) Add(r EscapeReason, detail string, callSites ...CallSite) { + es.add(r, detail, 0, callSites...) +} + +// IsEmpty returns true iff this Escapes is empty. +func (es *Escapes) IsEmpty() bool { + for _, cs := range es.CallSites { + if cs != nil { + return false + } } return true } +// Filter filters out all escapes except those matches the given reasons. +// +// If local is set, then non-local escapes will also be filtered. +func (es *Escapes) Filter(reasons []EscapeReason, local bool) { +FilterReasons: + for r := EscapeReason(0); r < reasonCount; r++ { + for i := 0; i < len(reasons); i++ { + if r == reasons[i] { + continue FilterReasons + } + } + // Zap this reason. + es.CallSites[r] = nil + es.Details[r] = "" + es.Omitted[r] = 0 + } + if !local { + return + } + for r := EscapeReason(0); r < reasonCount; r++ { + // Is does meet our local requirement? + if len(es.CallSites[r]) > 1 { + es.CallSites[r] = nil + es.Details[r] = "" + es.Omitted[r] = 0 + } + } +} + +// MergeWithCall merges these escapes with another. +// +// If callSite is nil, no call is added. +func (es *Escapes) MergeWithCall(other Escapes, callSite CallSite) { + for r := EscapeReason(0); r < reasonCount; r++ { + if other.CallSites[r] != nil { + // Construct our new call chain. + newCallSites := other.CallSites[r] + if callSite.IsValid() { + newCallSites = append([]CallSite{callSite}, newCallSites...) + } + // Add (potentially replacing) the underlying escape. + es.add(r, other.Details[r], other.Omitted[r], newCallSites...) + } + } +} + +// Reportf will call Reportf for each class of escapes. +func (es *Escapes) Reportf(pass *analysis.Pass) { + var b bytes.Buffer // Reused for all escapes. + for r := EscapeReason(0); r < reasonCount; r++ { + if es.CallSites[r] == nil { + continue + } + b.Reset() + fmt.Fprintf(&b, "%s ", r.String()) + if es.Omitted[r] > 0 { + fmt.Fprintf(&b, "(%d omitted) ", es.Omitted[r]) + } + for _, cs := range es.CallSites[r][1:] { + fmt.Fprintf(&b, "→ %s ", cs.Resolved.String()) + } + fmt.Fprintf(&b, "→ %s", es.Details[r]) + pass.Reportf(es.CallSites[r][0].LocalPos, b.String()) + } +} + +// MergeAll merges a sequence of escapes. +func MergeAll(others []Escapes) (es Escapes) { + for _, other := range others { + es.MergeWithCall(other, CallSite{}) + } + return +} + // loadObjdump reads the objdump output. // // This records if there is a call any function for every source line. It is // used only to remove false positives for escape analysis. The call will be // elided if escape analysis is able to put the object on the heap exclusively. -func loadObjdump() (map[LinePosition]string, error) { - f, err := os.Open(data.Objdump) +// +// Note that the map uses <basename.go>:<line> because that is all that is +// provided in the objdump format. Since this is all local, it is sufficient. +func loadObjdump() (map[string][]string, error) { + var ( + args []string + stdin io.Reader + ) + if *binary != "" { + args = append(args, *binary) + } else if Reader != nil { + stdin = Reader + } else { + // We have no input stream or binary. + return nil, fmt.Errorf("no binary or reader provided") + } + + // Construct our command. + cmd := exec.Command(*tool, args...) + cmd.Stdin = stdin + cmd.Stderr = os.Stderr + out, err := cmd.StdoutPipe() if err != nil { return nil, err } - defer f.Close() + if err := cmd.Start(); err != nil { + return nil, err + } // Build the map. - m := make(map[LinePosition]string) - r := bufio.NewReader(f) - var ( - lastField string - lastPos LinePosition - ) + m := make(map[string][]string) + r := bufio.NewReader(out) +NextLine: for { line, err := r.ReadString('\n') if err != nil && err != io.EOF { @@ -288,47 +421,73 @@ func loadObjdump() (map[LinePosition]string, error) { if !strings.Contains(fields[3], "CALL") { continue } + site := strings.TrimSpace(fields[0]) + var callStr string // Friendly string. + if len(fields) > 5 { + callStr = strings.Join(fields[5:], " ") + } + if len(callStr) == 0 { + // Just a raw call? is this asm? + callStr = strings.Join(fields[3:], " ") + } // Ignore strings containing duffzero, which is just // used by stack allocations for types that are large // enough to warrant Duff's device. - if strings.Contains(line, "runtime.duffzero") { + if strings.Contains(callStr, "runtime.duffzero") || + strings.Contains(callStr, "runtime.duffcopy") { continue } // Ignore the racefuncenter call, which is used for // race builds. This does not escape. - if strings.Contains(line, "runtime.racefuncenter") { + if strings.Contains(callStr, "runtime.racefuncenter") { continue } - // Calculate the filename and line. Note that per the - // example above, the filename is not a fully qualified - // base, just the basename (what we require). - if fields[0] != lastField { - parts := strings.SplitN(fields[0], ":", 2) - lineNum, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return nil, err - } - lastPos = LinePosition{ - Filename: parts[0], - Line: int(lineNum), - } - lastField = fields[0] + // Ignore the write barriers. + if strings.Contains(callStr, "runtime.gcWriteBarrier") { + continue } - if _, ok := m[lastPos]; ok { - continue // Already marked. + + // Ignore retpolines. + if strings.Contains(callStr, "runtime.retpoline") { + continue + } + + // Ignore stack sanity check (does not split). + if strings.Contains(callStr, "runtime.stackcheck") { + continue } - // Save the actual call for the detail. - m[lastPos] = strings.Join(fields[3:], " ") + // Ignore tls functions. + if strings.Contains(callStr, "runtime.settls") { + continue + } + + // Does this exist already? + existing, ok := m[site] + if !ok { + existing = make([]string, 0, 1) + } + for _, other := range existing { + if callStr == other { + continue NextLine + } + } + existing = append(existing, callStr) + m[site] = existing // Update. } if err == io.EOF { break } } + // Wait for the dump to finish. + if err := cmd.Wait(); err != nil { + return nil, err + } + return m, nil } @@ -337,65 +496,148 @@ type poser interface { Pos() token.Pos } +// runSelectEscapes runs with only select escapes. +func runSelectEscapes(pass *analysis.Pass) (interface{}, error) { + return run(pass, false) +} + +// runAllEscapes runs with all escapes included. +func runAllEscapes(pass *analysis.Pass) (interface{}, error) { + return run(pass, true) +} + +// findReasons extracts reasons from the function. +func findReasons(pass *analysis.Pass, fdecl *ast.FuncDecl) ([]EscapeReason, bool, map[EscapeReason]bool) { + // Is there a comment? + if fdecl.Doc == nil { + return nil, false, nil + } + var ( + reasons []EscapeReason + local bool + testReasons = make(map[EscapeReason]bool) // reason -> local? + ) + // Scan all lines. + found := false + for _, c := range fdecl.Doc.List { + // Does the comment contain a +checkescape line? + if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) { + continue + } + if c.Text == magic { + // Default: hard reasons, local only. + reasons = hardReasons + local = true + } else if strings.HasPrefix(c.Text, magicParams) { + // Extract specific reasons. + types := strings.Split(c.Text[len(magicParams):], ",") + found = true // For below. + for i := 0; i < len(types); i++ { + if types[i] == "local" { + // Limit search to local escapes. + local = true + } else if types[i] == "all" { + // Append all reasons. + reasons = append(reasons, allReasons...) + } else if types[i] == "hard" { + // Append all hard reasons. + reasons = append(reasons, hardReasons...) + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + reasons = append(reasons, r) + } + } + } else if strings.HasPrefix(c.Text, testMagic) { + types := strings.Split(c.Text[len(testMagic):], ",") + local := false + for i := 0; i < len(types); i++ { + if types[i] == "local" { + local = true + } else { + r, ok := escapeTypes[types[i]] + if !ok { + // This is not a valid escape reason. + pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) + continue + } + if v, ok := testReasons[r]; ok && v { + // Already registered as local. + continue + } + testReasons[r] = local + } + } + } + } + if len(reasons) == 0 && found { + // A magic annotation was provided, but no reasons. + pass.Reportf(fdecl.Pos(), "no reasons provided") + } + return reasons, local, testReasons +} + // run performs the analysis. -func run(pass *analysis.Pass) (interface{}, error) { +func run(pass *analysis.Pass, localEscapes bool) (interface{}, error) { calls, err := loadObjdump() if err != nil { return nil, err } - pef := packageEscapeFacts{ - Funcs: make(map[string][]Escape), - } + allEscapes := make(map[string][]Escapes) + mergedEscapes := make(map[string]Escapes) linePosition := func(inst, parent poser) LinePosition { p := pass.Fset.Position(inst.Pos()) if (p.Filename == "" || p.Line == 0) && parent != nil { p = pass.Fset.Position(parent.Pos()) } return LinePosition{ - Filename: filepath.Base(p.Filename), + Filename: p.Filename, Line: p.Line, } } - hasCall := func(inst poser) (string, bool) { - p := linePosition(inst, nil) - s, ok := calls[p] - return s, ok - } callSite := func(inst ssa.Instruction) CallSite { return CallSite{ LocalPos: inst.Pos(), Resolved: linePosition(inst, inst.Parent()), } } - escapes := func(reason EscapeReason, detail string, inst ssa.Instruction, ec *EscapeCount) []Escape { - if !ec.Record(reason) { - return nil // Skip. - } - es := Escape{ - Reason: reason, - Detail: detail, - Chain: []CallSite{callSite(inst)}, + hasCall := func(inst poser) (string, bool) { + p := linePosition(inst, nil) + s, ok := calls[p.Simplified()] + if !ok { + return "", false } - return []Escape{es} + // Join all calls together. + return strings.Join(s, " or "), true } - resolve := func(sub []Escape, inst ssa.Instruction, ec *EscapeCount) (es []Escape) { - for _, e := range sub { - if !ec.Record(e.Reason) { - continue // Skip. + state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) + + // Build the exception list. + exemptions := make(map[LinePosition]string) + for _, f := range pass.Files { + for _, cg := range f.Comments { + for _, c := range cg.List { + p := pass.Fset.Position(c.Slash) + if strings.HasPrefix(strings.ToLower(c.Text), exempt) { + exemptions[LinePosition{ + Filename: p.Filename, + Line: p.Line, + }] = c.Text[len(exempt):] + } } - es = append(es, Escape{ - Reason: e.Reason, - Detail: e.Detail, - Chain: append([]CallSite{callSite(inst)}, e.Chain...), - }) } - return es } - state := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) - var loadFunc func(*ssa.Function) []Escape // Used below. - - analyzeInstruction := func(inst ssa.Instruction, ec *EscapeCount) []Escape { + var loadFunc func(*ssa.Function) Escapes // Used below. + analyzeInstruction := func(inst ssa.Instruction) (es Escapes) { + cs := callSite(inst) + if _, ok := exemptions[cs.Resolved]; ok { + return // No escape. + } switch x := inst.(type) { case *ssa.Call: if x.Call.IsInvoke() { @@ -404,19 +646,15 @@ func run(pass *analysis.Pass) (interface{}, error) { // not, since we don't know the underlying // type. call, _ := hasCall(inst) - return escapes(interfaceInvoke, call, inst, ec) + es.Add(interfaceInvoke, call, cs) + return } switch x := x.Call.Value.(type) { case *ssa.Function: if x.Pkg == nil { // Can't resolve the package. - return escapes(unknownPackage, "no package", inst, ec) - } - - // Atomic functions are instrinics. We can - // assume that they don't escape. - if x.Pkg.Pkg.Name() == "atomic" { - return nil + es.Add(unknownPackage, "no package", cs) + return } // Is this a local function? If yes, call the @@ -424,7 +662,8 @@ func run(pass *analysis.Pass) (interface{}, error) { // local escapes are the escapes found in the // local function. if x.Pkg.Pkg == pass.Pkg { - return resolve(loadFunc(x), inst, ec) + es.MergeWithCall(loadFunc(x), cs) + return } // Recursively collect information from @@ -433,22 +672,26 @@ func run(pass *analysis.Pass) (interface{}, error) { if !pass.ImportPackageFact(x.Pkg.Pkg, &imp) { // Unable to import the dependency; we must // declare these as escaping. - return escapes(unknownPackage, "no analysis", inst, ec) + es.Add(unknownPackage, "no analysis", cs) + return } // The escapes of this instruction are the // escapes of the called function directly. - return resolve(imp.Funcs[x.RelString(x.Pkg.Pkg)], inst, ec) + // Note that this may record many escapes. + es.MergeWithCall(imp.Funcs[x.RelString(x.Pkg.Pkg)], cs) + return case *ssa.Builtin: // Ignore elided escapes. if _, has := hasCall(inst); !has { - return nil + return } // Check if the builtin is escaping. for _, name := range escapingBuiltins { if x.Name() == name { - return escapes(builtin, name, inst, ec) + es.Add(builtin, name, cs) + return } } default: @@ -457,82 +700,87 @@ func run(pass *analysis.Pass) (interface{}, error) { // dispatches. We cannot actually look up what // this refers to using static analysis alone. call, _ := hasCall(inst) - return escapes(dynamicCall, call, inst, ec) + es.Add(dynamicCall, call, cs) } case *ssa.Alloc: // Ignore non-heap allocations. if !x.Heap { - return nil + return } // Ignore elided escapes. call, has := hasCall(inst) if !has { - return nil + return } // This is a real heap allocation. - return escapes(allocation, call, inst, ec) + es.Add(allocation, call, cs) case *ssa.MakeMap: - return escapes(builtin, "makemap", inst, ec) + es.Add(builtin, "makemap", cs) case *ssa.MakeSlice: - return escapes(builtin, "makeslice", inst, ec) + es.Add(builtin, "makeslice", cs) case *ssa.MakeClosure: - return escapes(builtin, "makeclosure", inst, ec) + es.Add(builtin, "makeclosure", cs) case *ssa.MakeChan: - return escapes(builtin, "makechan", inst, ec) + es.Add(builtin, "makechan", cs) } - return nil // No escapes. + return } - var analyzeBasicBlock func(*ssa.BasicBlock, *EscapeCount) []Escape // Recursive. - analyzeBasicBlock = func(block *ssa.BasicBlock, ec *EscapeCount) (rval []Escape) { + var analyzeBasicBlock func(*ssa.BasicBlock) []Escapes // Recursive. + analyzeBasicBlock = func(block *ssa.BasicBlock) (rval []Escapes) { for _, inst := range block.Instrs { - rval = append(rval, analyzeInstruction(inst, ec)...) + if es := analyzeInstruction(inst); !es.IsEmpty() { + rval = append(rval, es) + } } - return rval // N.B. may be empty. + return } - loadFunc = func(fn *ssa.Function) []Escape { + loadFunc = func(fn *ssa.Function) Escapes { // Is this already available? name := fn.RelString(pass.Pkg) - if es, ok := pef.Funcs[name]; ok { + if es, ok := mergedEscapes[name]; ok { return es } // In the case of a true cycle, we assume that the current - // function itself has no escapes until the rest of the - // analysis is complete. This will trip the above in the case - // of a cycle of any kind. - pef.Funcs[name] = nil + // function itself has no escapes. + // + // When evaluating the function again, the proper escapes will + // be filled in here. + allEscapes[name] = nil + mergedEscapes[name] = Escapes{} // Perform the basic analysis. - var ( - es []Escape - ec EscapeCount - ) + var es []Escapes if fn.Recover != nil { - es = append(es, analyzeBasicBlock(fn.Recover, &ec)...) + es = append(es, analyzeBasicBlock(fn.Recover)...) } for _, block := range fn.Blocks { - es = append(es, analyzeBasicBlock(block, &ec)...) + es = append(es, analyzeBasicBlock(block)...) } // Check for a stack split. if call, has := hasCall(fn); has { - es = append(es, Escape{ - Reason: stackSplit, - Detail: call, - Chain: []CallSite{CallSite{ - LocalPos: fn.Pos(), - Resolved: linePosition(fn, fn.Parent()), - }}, + var ss Escapes + ss.Add(stackSplit, call, CallSite{ + LocalPos: fn.Pos(), + Resolved: linePosition(fn, fn.Parent()), }) + es = append(es, ss) } // Save the result and return. - pef.Funcs[name] = es - return es + // + // Note that we merge the result when saving to the facts. It + // doesn't really matter the specific escapes, as long as we + // have recorded all the appropriate classes of escapes. + summary := MergeAll(es) + allEscapes[name] = es + mergedEscapes[name] = summary + return summary } // Complete all local functions. @@ -540,173 +788,76 @@ func run(pass *analysis.Pass) (interface{}, error) { loadFunc(fn) } - // Build the exception list. - exemptions := make(map[LinePosition]string) - for _, f := range pass.Files { - for _, cg := range f.Comments { - for _, c := range cg.List { - p := pass.Fset.Position(c.Slash) - if strings.HasPrefix(strings.ToLower(c.Text), exempt) { - exemptions[LinePosition{ - Filename: filepath.Base(p.Filename), - Line: p.Line, - }] = c.Text[len(exempt):] - } - } - } - } - - // Delete everything matching the excemtions. - // - // This has the implication that exceptions are applied recursively, - // since this now modified set is what will be saved. - for name, escapes := range pef.Funcs { - var newEscapes []Escape - for _, escape := range escapes { - isExempt := false - for line, _ := range exemptions { - // Note that an exemption applies if it is - // marked as an exemption anywhere in the call - // chain. It need not be marked as escapes in - // the function itself, nor in the top-level - // caller. - for _, callSite := range escape.Chain { - if callSite.Resolved == line { - isExempt = true - break - } - } - if isExempt { - break - } - } - if !isExempt { - // Record this escape; not an exception. - newEscapes = append(newEscapes, escape) - } - } - pef.Funcs[name] = newEscapes // Update. + if !localEscapes { + // Export all findings for future packages. We only do this in + // non-local escapes mode, and expect to run this analysis + // after the SelectAnalysis. + pass.ExportPackageFact(&packageEscapeFacts{ + Funcs: mergedEscapes, + }) } - // Export all findings for future packages. - pass.ExportPackageFact(&pef) - // Scan all functions for violations. for _, f := range pass.Files { // Scan all declarations. for _, decl := range f.Decls { - fdecl, ok := decl.(*ast.FuncDecl) // Function declaration? + fdecl, ok := decl.(*ast.FuncDecl) if !ok { continue } - // Is there a comment? - if fdecl.Doc == nil { - continue - } var ( reasons []EscapeReason - found bool local bool - testReasons = make(map[EscapeReason]bool) // reason -> local? + testReasons map[EscapeReason]bool ) - // Does the comment contain a +checkescape line? - for _, c := range fdecl.Doc.List { - if !strings.HasPrefix(c.Text, magic) && !strings.HasPrefix(c.Text, testMagic) { - continue - } - if c.Text == magic { - // Default: hard reasons, local only. - reasons = hardReasons - local = true - } else if strings.HasPrefix(c.Text, magicParams) { - // Extract specific reasons. - types := strings.Split(c.Text[len(magicParams):], ",") - found = true // For below. - for i := 0; i < len(types); i++ { - if types[i] == "local" { - // Limit search to local escapes. - local = true - } else if types[i] == "all" { - // Append all reasons. - reasons = append(reasons, allReasons...) - } else if types[i] == "hard" { - // Append all hard reasons. - reasons = append(reasons, hardReasons...) - } else { - r, ok := escapeTypes[types[i]] - if !ok { - // This is not a valid escape reason. - pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) - continue - } - reasons = append(reasons, r) - } - } - } else if strings.HasPrefix(c.Text, testMagic) { - types := strings.Split(c.Text[len(testMagic):], ",") - local := false - for i := 0; i < len(types); i++ { - if types[i] == "local" { - local = true - } else { - r, ok := escapeTypes[types[i]] - if !ok { - // This is not a valid escape reason. - pass.Reportf(fdecl.Pos(), "unknown reason: %v", types[i]) - continue - } - if v, ok := testReasons[r]; ok && v { - // Already registered as local. - continue - } - testReasons[r] = local - } - } - } - } - if len(reasons) == 0 && found { - // A magic annotation was provided, but no reasons. - pass.Reportf(fdecl.Pos(), "no reasons provided") - continue + if localEscapes { + // Find all hard escapes. + reasons = hardReasons + } else { + // Find all declared reasons. + reasons, local, testReasons = findReasons(pass, fdecl) } // Scan for matches. fn := pass.TypesInfo.Defs[fdecl.Name].(*types.Func) - name := state.Pkg.Prog.FuncValue(fn).RelString(pass.Pkg) - es, ok := pef.Funcs[name] - if !ok { + fv := state.Pkg.Prog.FuncValue(fn) + if fv == nil { + continue + } + name := fv.RelString(pass.Pkg) + all, allOk := allEscapes[name] + merged, mergedOk := mergedEscapes[name] + if !allOk || !mergedOk { pass.Reportf(fdecl.Pos(), "internal error: function %s not found.", name) continue } - for _, e := range es { - for _, r := range reasons { - // Is does meet our local requirement? - if local && len(e.Chain) > 1 { - continue - } - // Does this match the reason? Emit - // with a full stack trace that - // explains why this violates our - // constraints. - if e.Reason == r { - pass.Reportf(e.Chain[0].LocalPos, "%s", e.String()) - } - } + + // Filter reasons and report. + // + // For the findings, we use all escapes. + for _, es := range all { + es.Filter(reasons, local) + es.Reportf(pass) } // Scan for test (required) matches. + // + // For tests we need only the merged escapes. testReasonsFound := make(map[EscapeReason]bool) - for _, e := range es { + for r := EscapeReason(0); r < reasonCount; r++ { + if merged.CallSites[r] == nil { + continue + } // Is this local? - local, ok := testReasons[e.Reason] - wantLocal := len(e.Chain) == 1 - testReasonsFound[e.Reason] = wantLocal + wantLocal, ok := testReasons[r] + isLocal := len(merged.CallSites[r]) == 1 + testReasonsFound[r] = isLocal if !ok { continue } - if local == wantLocal { - delete(testReasons, e.Reason) + if isLocal == wantLocal { + delete(testReasons, r) } } for reason, local := range testReasons { @@ -714,10 +865,8 @@ func run(pass *analysis.Pass) (interface{}, error) { pass.Reportf(fdecl.Pos(), fmt.Sprintf("testescapes not found: reason=%s, local=%t", reason, local)) } if len(testReasons) > 0 { - // Dump all reasons found to help in debugging. - for _, e := range es { - pass.Reportf(e.Chain[0].LocalPos, "escape found: %s", e.String()) - } + // Report for debugging. + merged.Reportf(pass) } } } diff --git a/tools/checkescape/test1/test1.go b/tools/checkescape/test1/test1.go index 68d3f72cc..27991649f 100644 --- a/tools/checkescape/test1/test1.go +++ b/tools/checkescape/test1/test1.go @@ -17,7 +17,6 @@ package test1 import ( "fmt" - "reflect" ) // Interface is a generic interface. @@ -163,20 +162,6 @@ func dynamicRec(f func()) { Dynamic(f) } -// +mustescape:local,unknown -//go:noinline -//go:nosplit -func Unknown() { - _ = reflect.TypeOf((*Type)(nil)) // Does not actually escape. -} - -// +mustescape:unknown -//go:noinline -//go:nosplit -func unknownRec() { - Unknown() -} - //go:noinline //go:nosplit func internalFunc() { @@ -190,6 +175,7 @@ func Split() { // +mustescape:stack //go:noinline +//go:nosplit func splitRec() { Split() } diff --git a/tools/checkescape/test2/test2.go b/tools/checkescape/test2/test2.go index 7fce3e3be..067d5a1f4 100644 --- a/tools/checkescape/test2/test2.go +++ b/tools/checkescape/test2/test2.go @@ -81,14 +81,9 @@ func dynamicCrossPkg(f func()) { test1.Dynamic(f) } -// +mustescape:unknown -//go:noinline -func unknownCrossPkg() { - test1.Unknown() -} - // +mustescape:stack //go:noinline +//go:nosplit func splitCrosssPkt() { test1.Split() } diff --git a/tools/defs.bzl b/tools/defs.bzl index e71a26cf4..290d564f2 100644 --- a/tools/defs.bzl +++ b/tools/defs.bzl @@ -27,7 +27,6 @@ gbenchmark = _gbenchmark gazelle = _gazelle go_embed_data = _go_embed_data go_path = _go_path -go_test = _go_test gtest = _gtest grpcpp = _grpcpp loopback = _loopback @@ -45,17 +44,35 @@ vdso_linker_option = _vdso_linker_option default_platform = _default_platform platforms = _platforms -def go_binary(name, **kwargs): +def go_binary(name, nogo = True, pure = False, static = False, x_defs = None, **kwargs): """Wraps the standard go_binary. Args: name: the rule name. + nogo: enable nogo analysis. + pure: build a pure Go (no CGo) binary. + static: build a static binary. + x_defs: additional linker definitions. **kwargs: standard go_binary arguments. """ _go_binary( name = name, + pure = pure, + static = static, + x_defs = x_defs, **kwargs ) + if nogo: + # Note that the nogo rule applies only for go_library and go_test + # targets, therefore we construct a library from the binary sources. + _go_library( + name = name + "_nogo_library", + **kwargs + ) + nogo_test( + name = name + "_nogo", + deps = [":" + name + "_nogo_library"], + ) def calculate_sets(srcs): """Calculates special Go sets for templates. @@ -119,6 +136,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F stateify: whether statify is enabled (default: true). marshal: whether marshal is enabled (default: false). marshal_debug: whether the gomarshal tools emits debugging output (default: false). + nogo: enable nogo analysis. **kwargs: standard go_library arguments. """ all_srcs = srcs @@ -202,6 +220,24 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F **kwargs ) +def go_test(name, nogo = True, **kwargs): + """Wraps the standard go_test. + + Args: + name: the rule name. + nogo: enable nogo analysis. + **kwargs: standard go_test arguments. + """ + _go_test( + name = name, + **kwargs + ) + if nogo: + nogo_test( + name = name + "_nogo", + deps = [":" + name], + ) + def proto_library(name, srcs, deps = None, has_services = 0, **kwargs): """Wraps the standard proto_library. diff --git a/tools/go_marshal/README.md b/tools/go_marshal/README.md index 68d759083..75e5c7888 100644 --- a/tools/go_marshal/README.md +++ b/tools/go_marshal/README.md @@ -3,18 +3,19 @@ This package implements the go_marshal utility. # Overview `go_marshal` is a code generation utility similar to `go_stateify` for -automatically generating code to marshal go data structures to memory. +marshalling go data structures to and from memory. `go_marshal` attempts to improve on `binary.Write` and the sentry's -`binary.Marshal` by moving the go runtime reflection necessary to marshal a -struct to compile-time. +`binary.Marshal` by moving the expensive use of reflection from runtime to +compile-time. `go_marshal` automatically generates implementations for `marshal.Marshallable` -and `safemem.{Reader,Writer}`. Data structures that require custom serialization -will have manual implementations for these interfaces. +interface. Data structures that require custom serialization can be accomodated +through a manual implementation this interface. Data structures can be flagged for code generation by adding a struct-level -comment `// +marshal`. +comment `// +marshal`. For additional details and options, see the documentation +for the `marshal.Marshallable` interface. # Usage @@ -74,7 +75,7 @@ intended for ABI structs, which have these additional restrictions: dependent native pointer size. - Fields must either be a primitive integer type (`byte`, - `[u]int{8,16,32,64}`), or of a type that implements abi.Marshallable. + `[u]int{8,16,32,64}`), or of a type that implements `marshal.Marshallable`. - `int` and `uint` fields are not allowed. Use an explicitly-sized numeric type. diff --git a/tools/go_marshal/defs.bzl b/tools/go_marshal/defs.bzl index 323e33882..ba98f3599 100644 --- a/tools/go_marshal/defs.bzl +++ b/tools/go_marshal/defs.bzl @@ -56,7 +56,7 @@ marshal_deps = [ "//pkg/gohacks", "//pkg/safecopy", "//pkg/usermem", - "//tools/go_marshal/marshal", + "//pkg/marshal", ] # marshal_test_deps are required by test targets. diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 19bcd4e6a..72ed6d109 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -38,8 +38,8 @@ import ( // All recievers are single letters, so we don't allow import aliases to be a // single letter. var badIdents = []string{ - "addr", "blk", "buf", "dst", "dsts", "count", "err", "hdr", "idx", "inner", - "length", "limit", "ptr", "size", "src", "srcs", "task", "val", + "addr", "blk", "buf", "cc", "dst", "dsts", "count", "err", "hdr", "idx", + "inner", "length", "limit", "ptr", "size", "src", "srcs", "val", // All single-letter identifiers. } @@ -107,7 +107,7 @@ func NewGenerator(srcs []string, out, outTest, pkg string, imports []string) (*G g.imports.add("gvisor.dev/gvisor/pkg/gohacks") g.imports.add("gvisor.dev/gvisor/pkg/safecopy") g.imports.add("gvisor.dev/gvisor/pkg/usermem") - g.imports.add("gvisor.dev/gvisor/tools/go_marshal/marshal") + g.imports.add("gvisor.dev/gvisor/pkg/marshal") return &g, nil } diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index e3c3dac63..cf76b5241 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -224,7 +224,7 @@ func (g *interfaceGenerator) emitNoEscapeSliceDataPointer(srcPtr, dstVar string) func (g *interfaceGenerator) emitKeepAlive(ptrVar string) { g.emit("// Since we bypassed the compiler's escape analysis, indicate that %s\n", ptrVar) g.emit("// must live until the use above.\n") - g.emit("runtime.KeepAlive(%s)\n", ptrVar) + g.emit("runtime.KeepAlive(%s) // escapes: replaced by intrinsic.\n", ptrVar) } func (g *interfaceGenerator) expandBinaryExpr(b *strings.Builder, e *ast.BinaryExpr) { diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go index 72ef03a22..7525b52da 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -102,11 +102,11 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) @@ -114,19 +114,19 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go index 39f654ea8..7edaf666c 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -154,11 +154,11 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("// CopyOutN implements marshal.Marshallable.CopyOutN.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) @@ -166,19 +166,19 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("// CopyOut implements marshal.Marshallable.CopyOut.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") g.emit("// CopyIn implements marshal.Marshallable.CopyIn.\n") g.emit("//go:nosplit\n") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") }) @@ -211,7 +211,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, eltType) g.emit("//go:nosplit\n") - g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) + g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, eltType) g.inIndent(func() { g.emit("count := len(dst)\n") g.emit("if count == 0 {\n") @@ -223,7 +223,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emitCastSliceToByteSlice("&dst", "buf", "size * count") - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive("dst") g.emit("return length, err\n") }) @@ -231,7 +231,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, eltType) g.emit("//go:nosplit\n") - g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) + g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, eltType) g.inIndent(func() { g.emit("count := len(src)\n") g.emit("if count == 0 {\n") @@ -243,7 +243,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForPrimitiveNewtype(nt *ast.Id g.emitCastSliceToByteSlice("&src", "buf", "size * count") - g.emit("length, err := task.CopyOutBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive("src") g.emit("return length, err\n") }) diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 4b9cea08a..d3fc1c1c6 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -323,13 +323,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOutN(cc marshal.CopyContext, addr usermem.Addr, limit int) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) g.emit("%s.MarshalBytes(buf) // escapes: fallback.\n", g.r) - g.emit("return task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("return cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") } if thisPacked { g.recordUsedImport("reflect") @@ -343,7 +343,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { // Fast serialization. g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf[:limit]) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") } else { @@ -356,9 +356,9 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyOut(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("return %s.CopyOutN(task, addr, %s.SizeBytes())\n", g.r, g.r) + g.emit("return %s.CopyOutN(cc, addr, %s.SizeBytes())\n", g.r, g.r) }) g.emit("}\n\n") @@ -366,12 +366,12 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("//go:nosplit\n") g.recordUsedImport("marshal") g.recordUsedImport("usermem") - g.emit("func (%s *%s) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) CopyIn(cc marshal.CopyContext, addr usermem.Addr) (int, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("buf := cc.CopyScratchBuffer(%s.SizeBytes()) // escapes: okay.\n", g.r) + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emit("// Unmarshal unconditionally. If we had a short copy-in, this results in a\n") g.emit("// partially unmarshalled struct.\n") g.emit("%s.UnmarshalBytes(buf) // escapes: fallback.\n", g.r) @@ -389,7 +389,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { // Fast deserialization. g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := task.CopyInBytes(addr, buf) // escapes: okay.\n") + g.emit("length, err := cc.CopyInBytes(addr, buf) // escapes: okay.\n") g.emitKeepAlive(g.r) g.emit("return length, err\n") } else { @@ -400,13 +400,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("// WriteTo implements io.WriterTo.WriteTo.\n") g.recordUsedImport("io") - g.emit("func (%s *%s) WriteTo(w io.Writer) (int64, error) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) WriteTo(writer io.Writer) (int64, error) {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := make([]byte, %s.SizeBytes())\n", g.r) g.emit("%s.MarshalBytes(buf)\n", g.r) - g.emit("length, err := w.Write(buf)\n") + g.emit("length, err := writer.Write(buf)\n") g.emit("return int64(length), err\n") } if thisPacked { @@ -421,7 +421,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { // Fast serialization. g.emitCastToByteSlice(g.r, "buf", fmt.Sprintf("%s.SizeBytes()", g.r)) - g.emit("length, err := w.Write(buf)\n") + g.emit("length, err := writer.Write(buf)\n") g.emitKeepAlive(g.r) g.emit("return int64(length), err\n") } else { @@ -442,7 +442,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.recordUsedImport("usermem") g.emit("// Copy%sIn copies in a slice of %s objects from the task's memory.\n", slice.ident, g.typeName()) - g.emit("func Copy%sIn(task marshal.Task, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) + g.emit("func Copy%sIn(cc marshal.CopyContext, addr usermem.Addr, dst []%s) (int, error) {\n", slice.ident, g.typeName()) g.inIndent(func() { g.emit("count := len(dst)\n") g.emit("if count == 0 {\n") @@ -454,8 +454,8 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(size * count)\n") - g.emit("length, err := task.CopyInBytes(addr, buf)\n\n") + g.emit("buf := cc.CopyScratchBuffer(size * count)\n") + g.emit("length, err := cc.CopyInBytes(addr, buf)\n\n") g.emit("// Unmarshal as much as possible, even on error. First handle full objects.\n") g.emit("limit := length/size\n") @@ -489,7 +489,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, // Fast deserialization. g.emitCastSliceToByteSlice("&dst", "buf", "size * count") - g.emit("length, err := task.CopyInBytes(addr, buf)\n") + g.emit("length, err := cc.CopyInBytes(addr, buf)\n") g.emitKeepAlive("dst") g.emit("return length, err\n") } else { @@ -499,7 +499,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("}\n\n") g.emit("// Copy%sOut copies a slice of %s objects to the task's memory.\n", slice.ident, g.typeName()) - g.emit("func Copy%sOut(task marshal.Task, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) + g.emit("func Copy%sOut(cc marshal.CopyContext, addr usermem.Addr, src []%s) (int, error) {\n", slice.ident, g.typeName()) g.inIndent(func() { g.emit("count := len(src)\n") g.emit("if count == 0 {\n") @@ -511,13 +511,13 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) - g.emit("buf := task.CopyScratchBuffer(size * count)\n") + g.emit("buf := cc.CopyScratchBuffer(size * count)\n") g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") }) g.emit("}\n") - g.emit("return task.CopyOutBytes(addr, buf)\n") + g.emit("return cc.CopyOutBytes(addr, buf)\n") } if thisPacked { g.recordUsedImport("reflect") @@ -531,7 +531,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, // Fast serialization. g.emitCastSliceToByteSlice("&src", "buf", "size * count") - g.emit("length, err := task.CopyOutBytes(addr, buf)\n") + g.emit("length, err := cc.CopyOutBytes(addr, buf)\n") g.emitKeepAlive("src") g.emit("return length, err\n") } else { diff --git a/tools/go_marshal/test/BUILD b/tools/go_marshal/test/BUILD index 3d989823a..4b27773c2 100644 --- a/tools/go_marshal/test/BUILD +++ b/tools/go_marshal/test/BUILD @@ -35,10 +35,10 @@ go_test( srcs = ["marshal_test.go"], deps = [ ":test", + "//pkg/marshal", "//pkg/syserror", "//pkg/usermem", "//tools/go_marshal/analysis", - "//tools/go_marshal/marshal", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/tools/go_marshal/test/escape/BUILD b/tools/go_marshal/test/escape/BUILD index f74e6ffae..2981ef196 100644 --- a/tools/go_marshal/test/escape/BUILD +++ b/tools/go_marshal/test/escape/BUILD @@ -7,8 +7,8 @@ go_library( testonly = 1, srcs = ["escape.go"], deps = [ + "//pkg/marshal", "//pkg/usermem", - "//tools/go_marshal/marshal", "//tools/go_marshal/test", ], ) diff --git a/tools/go_marshal/test/escape/escape.go b/tools/go_marshal/test/escape/escape.go index 6a46ddbf8..7f62b0a2b 100644 --- a/tools/go_marshal/test/escape/escape.go +++ b/tools/go_marshal/test/escape/escape.go @@ -15,34 +15,34 @@ package escape import ( + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/usermem" - "gvisor.dev/gvisor/tools/go_marshal/marshal" "gvisor.dev/gvisor/tools/go_marshal/test" ) -// dummyTask implements marshal.Task. -type dummyTask struct { +// dummyCopyContext implements marshal.CopyContext. +type dummyCopyContext struct { } -func (*dummyTask) CopyScratchBuffer(size int) []byte { +func (*dummyCopyContext) CopyScratchBuffer(size int) []byte { return make([]byte, size) } -func (*dummyTask) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { +func (*dummyCopyContext) CopyOutBytes(addr usermem.Addr, b []byte) (int, error) { return len(b), nil } -func (*dummyTask) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { +func (*dummyCopyContext) CopyInBytes(addr usermem.Addr, b []byte) (int, error) { return len(b), nil } -func (t *dummyTask) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { +func (t *dummyCopyContext) MarshalBytes(addr usermem.Addr, marshallable marshal.Marshallable) { buf := t.CopyScratchBuffer(marshallable.SizeBytes()) marshallable.MarshalBytes(buf) t.CopyOutBytes(addr, buf) } -func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { +func (t *dummyCopyContext) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marshallable) { buf := t.CopyScratchBuffer(marshallable.SizeBytes()) marshallable.MarshalUnsafe(buf) t.CopyOutBytes(addr, buf) @@ -50,21 +50,22 @@ func (t *dummyTask) MarshalUnsafe(addr usermem.Addr, marshallable marshal.Marsha // +checkescape:all //go:nosplit -func doCopyIn(t *dummyTask) { +func doCopyIn(t *dummyCopyContext) { var stat test.Stat stat.CopyIn(t, usermem.Addr(0xf000ba12)) } // +checkescape:all //go:nosplit -func doCopyOut(t *dummyTask) { +func doCopyOut(t *dummyCopyContext) { var stat test.Stat stat.CopyOut(t, usermem.Addr(0xf000ba12)) } // +mustescape:builtin // +mustescape:stack -func doMarshalBytesDirect(t *dummyTask) { +//go:nosplit +func doMarshalBytesDirect(t *dummyCopyContext) { var stat test.Stat buf := t.CopyScratchBuffer(stat.SizeBytes()) stat.MarshalBytes(buf) @@ -73,7 +74,8 @@ func doMarshalBytesDirect(t *dummyTask) { // +mustescape:builtin // +mustescape:stack -func doMarshalUnsafeDirect(t *dummyTask) { +//go:nosplit +func doMarshalUnsafeDirect(t *dummyCopyContext) { var stat test.Stat buf := t.CopyScratchBuffer(stat.SizeBytes()) stat.MarshalUnsafe(buf) @@ -82,14 +84,16 @@ func doMarshalUnsafeDirect(t *dummyTask) { // +mustescape:local,heap // +mustescape:stack -func doMarshalBytesViaMarshallable(t *dummyTask) { +//go:nosplit +func doMarshalBytesViaMarshallable(t *dummyCopyContext) { var stat test.Stat t.MarshalBytes(usermem.Addr(0xf000ba12), &stat) } // +mustescape:local,heap // +mustescape:stack -func doMarshalUnsafeViaMarshallable(t *dummyTask) { +//go:nosplit +func doMarshalUnsafeViaMarshallable(t *dummyCopyContext) { var stat test.Stat t.MarshalUnsafe(usermem.Addr(0xf000ba12), &stat) } diff --git a/tools/go_marshal/test/marshal_test.go b/tools/go_marshal/test/marshal_test.go index 16829ee45..a00f9a684 100644 --- a/tools/go_marshal/test/marshal_test.go +++ b/tools/go_marshal/test/marshal_test.go @@ -27,22 +27,22 @@ import ( "unsafe" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/tools/go_marshal/analysis" - "gvisor.dev/gvisor/tools/go_marshal/marshal" "gvisor.dev/gvisor/tools/go_marshal/test" ) var simulatedErr error = syserror.EFAULT -// mockTask implements marshal.Task. -type mockTask struct { +// mockCopyContext implements marshal.CopyContext. +type mockCopyContext struct { taskMem usermem.BytesIO } // populate fills the task memory with the contents of val. -func (t *mockTask) populate(val interface{}) { +func (t *mockCopyContext) populate(val interface{}) { var buf bytes.Buffer // Use binary.Write so we aren't testing go-marshal against its own // potentially buggy implementation. @@ -52,7 +52,7 @@ func (t *mockTask) populate(val interface{}) { t.taskMem.Bytes = buf.Bytes() } -func (t *mockTask) setLimit(n int) { +func (t *mockCopyContext) setLimit(n int) { if len(t.taskMem.Bytes) < n { grown := make([]byte, n) copy(grown, t.taskMem.Bytes) @@ -62,22 +62,22 @@ func (t *mockTask) setLimit(n int) { t.taskMem.Bytes = t.taskMem.Bytes[:n] } -// CopyScratchBuffer implements marshal.Task.CopyScratchBuffer. -func (t *mockTask) CopyScratchBuffer(size int) []byte { +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (t *mockCopyContext) CopyScratchBuffer(size int) []byte { return make([]byte, size) } -// CopyOutBytes implements marshal.Task.CopyOutBytes. The implementation +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. The implementation // completely ignores the target address and stores a copy of b in its // internally buffer, overriding any previous contents. -func (t *mockTask) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) { +func (t *mockCopyContext) CopyOutBytes(_ usermem.Addr, b []byte) (int, error) { return t.taskMem.CopyOut(nil, 0, b, usermem.IOOpts{}) } -// CopyInBytes implements marshal.Task.CopyInBytes. The implementation +// CopyInBytes implements marshal.CopyContext.CopyInBytes. The implementation // completely ignores the source address and always fills b from the begining of // its internal buffer. -func (t *mockTask) CopyInBytes(_ usermem.Addr, b []byte) (int, error) { +func (t *mockCopyContext) CopyInBytes(_ usermem.Addr, b []byte) (int, error) { return t.taskMem.CopyIn(nil, 0, b, usermem.IOOpts{}) } @@ -171,11 +171,11 @@ func compareMemory(t *testing.T, expected, actual []byte, n int) { // dst. The task signals an error at limit bytes during copy-in, which should // result in a truncated unmarshalling. func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) { - var task mockTask - task.populate(src) - task.setLimit(limit) + var cc mockCopyContext + cc.populate(src) + cc.setLimit(limit) - n, err := dst.CopyIn(&task, usermem.Addr(0)) + n, err := dst.CopyIn(&cc, usermem.Addr(0)) if n != limit { t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -202,10 +202,10 @@ func limitedCopyIn(t *testing.T, src, dst marshal.Marshallable, limit int) { // limitedCopyOut marshals src to task memory. The task signals an error at // limit bytes during copy-out, which should result in a truncated marshalling. func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { - var task mockTask - task.setLimit(limit) + var cc mockCopyContext + cc.setLimit(limit) - n, err := src.CopyOut(&task, usermem.Addr(0)) + n, err := src.CopyOut(&cc, usermem.Addr(0)) if n != limit { t.Errorf("CopyOut copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -215,7 +215,7 @@ func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { expectedMem := unsafeMemory(src) defer runtime.KeepAlive(src) - actualMem := task.taskMem.Bytes + actualMem := cc.taskMem.Bytes compareMemory(t, expectedMem, actualMem, n) } @@ -223,10 +223,10 @@ func limitedCopyOut(t *testing.T, src marshal.Marshallable, limit int) { // copyOutN marshals src to task memory, requesting the marshalling to be // limited to limit bytes. func copyOutN(t *testing.T, src marshal.Marshallable, limit int) { - var task mockTask - task.setLimit(limit) + var cc mockCopyContext + cc.setLimit(limit) - n, err := src.CopyOutN(&task, usermem.Addr(0), limit) + n, err := src.CopyOutN(&cc, usermem.Addr(0), limit) if err != nil { t.Errorf("CopyOut returned unexpected error: %v", err) } @@ -236,7 +236,7 @@ func copyOutN(t *testing.T, src marshal.Marshallable, limit int) { expectedMem := unsafeMemory(src) defer runtime.KeepAlive(src) - actualMem := task.taskMem.Bytes + actualMem := cc.taskMem.Bytes t.Logf("Expected: %v + %v\n", expectedMem[:n], expectedMem[n:]) t.Logf("Actual : %v + %v\n", actualMem[:n], actualMem[n:]) @@ -303,20 +303,20 @@ func TestLimitedMarshalling(t *testing.T) { func TestLimitedSliceMarshalling(t *testing.T) { types := []struct { arrayPtrType reflect.Type - copySliceIn func(task marshal.Task, addr usermem.Addr, dstSlice interface{}) (int, error) - copySliceOut func(task marshal.Task, addr usermem.Addr, srcSlice interface{}) (int, error) + copySliceIn func(cc marshal.CopyContext, addr usermem.Addr, dstSlice interface{}) (int, error) + copySliceOut func(cc marshal.CopyContext, addr usermem.Addr, srcSlice interface{}) (int, error) unsafeMemory func(arrPtr interface{}) []byte }{ // Packed types. { reflect.TypeOf((*[20]test.Stat)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[20]test.Stat)[:] - return test.CopyStatSliceIn(task, addr, slice) + return test.CopyStatSliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[20]test.Stat)[:] - return test.CopyStatSliceOut(task, addr, slice) + return test.CopyStatSliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[20]test.Stat)[:] @@ -325,13 +325,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[1]test.Stat)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[1]test.Stat)[:] - return test.CopyStatSliceIn(task, addr, slice) + return test.CopyStatSliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[1]test.Stat)[:] - return test.CopyStatSliceOut(task, addr, slice) + return test.CopyStatSliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[1]test.Stat)[:] @@ -340,13 +340,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[5]test.SignalSetAlias)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[5]test.SignalSetAlias)[:] - return test.CopySignalSetAliasSliceIn(task, addr, slice) + return test.CopySignalSetAliasSliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[5]test.SignalSetAlias)[:] - return test.CopySignalSetAliasSliceOut(task, addr, slice) + return test.CopySignalSetAliasSliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[5]test.SignalSetAlias)[:] @@ -356,13 +356,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { // Non-packed types. { reflect.TypeOf((*[20]test.Type1)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[20]test.Type1)[:] - return test.CopyType1SliceIn(task, addr, slice) + return test.CopyType1SliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[20]test.Type1)[:] - return test.CopyType1SliceOut(task, addr, slice) + return test.CopyType1SliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[20]test.Type1)[:] @@ -371,13 +371,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[1]test.Type1)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[1]test.Type1)[:] - return test.CopyType1SliceIn(task, addr, slice) + return test.CopyType1SliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[1]test.Type1)[:] - return test.CopyType1SliceOut(task, addr, slice) + return test.CopyType1SliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[1]test.Type1)[:] @@ -386,13 +386,13 @@ func TestLimitedSliceMarshalling(t *testing.T) { }, { reflect.TypeOf((*[7]test.Type8)(nil)), - func(task marshal.Task, addr usermem.Addr, dst interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, dst interface{}) (int, error) { slice := dst.(*[7]test.Type8)[:] - return test.CopyType8SliceIn(task, addr, slice) + return test.CopyType8SliceIn(cc, addr, slice) }, - func(task marshal.Task, addr usermem.Addr, src interface{}) (int, error) { + func(cc marshal.CopyContext, addr usermem.Addr, src interface{}) (int, error) { slice := src.(*[7]test.Type8)[:] - return test.CopyType8SliceOut(task, addr, slice) + return test.CopyType8SliceOut(cc, addr, slice) }, func(a interface{}) []byte { slice := a.(*[7]test.Type8)[:] @@ -439,11 +439,11 @@ func TestLimitedSliceMarshalling(t *testing.T) { limit += elem.SizeBytes() / 2 analysis.RandomizeValue(expected) - var task mockTask - task.populate(expected) - task.setLimit(limit) + var cc mockCopyContext + cc.populate(expected) + cc.setLimit(limit) - n, err := tt.copySliceIn(&task, usermem.Addr(0), actual) + n, err := tt.copySliceIn(&cc, usermem.Addr(0), actual) if n != limit { t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -493,11 +493,11 @@ func TestLimitedSliceMarshalling(t *testing.T) { limit += elem.SizeBytes() / 2 analysis.RandomizeValue(expected) - var task mockTask - task.populate(expected) - task.setLimit(limit) + var cc mockCopyContext + cc.populate(expected) + cc.setLimit(limit) - n, err := tt.copySliceOut(&task, usermem.Addr(0), expected) + n, err := tt.copySliceOut(&cc, usermem.Addr(0), expected) if n != limit { t.Errorf("CopyIn copied unexpected number of bytes, expected %d, got %d", limit, n) } @@ -507,7 +507,7 @@ func TestLimitedSliceMarshalling(t *testing.T) { expectedMem := tt.unsafeMemory(expected) defer runtime.KeepAlive(expected) - actualMem := task.taskMem.Bytes + actualMem := cc.taskMem.Bytes compareMemory(t, expectedMem, actualMem, n) }) diff --git a/tools/issue_reviver/BUILD b/tools/issue_reviver/BUILD index 4ef1a3124..35b0111ca 100644 --- a/tools/issue_reviver/BUILD +++ b/tools/issue_reviver/BUILD @@ -5,6 +5,7 @@ package(licenses = ["notice"]) go_binary( name = "issue_reviver", srcs = ["main.go"], + nogo = False, deps = [ "//tools/issue_reviver/github", "//tools/issue_reviver/reviver", diff --git a/tools/issue_reviver/github/BUILD b/tools/issue_reviver/github/BUILD index 0eabc2835..555abd296 100644 --- a/tools/issue_reviver/github/BUILD +++ b/tools/issue_reviver/github/BUILD @@ -21,4 +21,5 @@ go_test( size = "small", srcs = ["github_test.go"], library = ":github", + nogo = False, ) diff --git a/tools/make_apt.sh b/tools/make_apt.sh index 3fb1066e5..13c5edd76 100755 --- a/tools/make_apt.sh +++ b/tools/make_apt.sh @@ -54,18 +54,22 @@ declare -r release="${root}/dists/${suite}" mkdir -p "${release}" # Create a temporary keyring, and ensure it is cleaned up. +# Using separate homedir allows us to install apt repositories multiple times +# using the same key. This is a limitation in GnuPG pre-2.1. declare -r keyring=$(mktemp /tmp/keyringXXXXXX.gpg) +declare -r homedir=$(mktemp -d /tmp/homedirXXXXXX) +declare -r gpg_opts=("--no-default-keyring" "--secret-keyring" "${keyring}" "--homedir" "${homedir}") cleanup() { - rm -f "${keyring}" + rm -rf "${keyring}" "${homedir}" } trap cleanup EXIT # We attempt the import twice because the first one will fail if the public key # is not found. This isn't actually a failure for us, because we don't require -# the public (this may be stored separately). The second import will succeed +# the public key (this may be stored separately). The second import will succeed # because, in reality, the first import succeeded and it's a no-op. -gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" || \ - gpg --no-default-keyring --keyring "${keyring}" --import "${private_key}" +gpg "${gpg_opts[@]}" --import "${private_key}" || \ + gpg "${gpg_opts[@]}" --import "${private_key}" # Copy the packages into the root. for pkg in "$@"; do @@ -100,7 +104,8 @@ for pkg in "$@"; do cp -a "${pkg}" "${target}" chmod 0644 "${target}" if [[ "${ext}" == "deb" ]]; then - dpkg-sig -g "--no-default-keyring --keyring ${keyring}" --sign builder "${target}" + # We use [*] here to expand the gpg_opts array into a single shell-word. + dpkg-sig -g "${gpg_opts[*]}" --sign builder "${target}" fi done @@ -135,5 +140,5 @@ rm "${release}"/apt.conf # Sign the release. declare -r digest_opts=("--digest-algo" "SHA512" "--cert-digest-algo" "SHA512") (cd "${release}" && rm -f Release.gpg InRelease) -(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" --clearsign "${digest_opts[@]}" -o InRelease Release) -(cd "${release}" && gpg --no-default-keyring --keyring "${keyring}" -abs "${digest_opts[@]}" -o Release.gpg Release) +(cd "${release}" && gpg "${gpg_opts[@]}" --clearsign "${digest_opts[@]}" -o InRelease Release) +(cd "${release}" && gpg "${gpg_opts[@]}" -abs "${digest_opts[@]}" -o Release.gpg Release) diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index e1bfb9a2c..9f1fcd9c7 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -1,7 +1,18 @@ load("//tools:defs.bzl", "bzl_library", "go_library") +load("//tools/nogo:defs.bzl", "nogo_dump_tool", "nogo_stdlib") package(licenses = ["notice"]) +nogo_dump_tool( + name = "dump_tool", + visibility = ["//visibility:public"], +) + +nogo_stdlib( + name = "stdlib", + visibility = ["//visibility:public"], +) + go_library( name = "nogo", srcs = [ @@ -16,7 +27,6 @@ go_library( deps = [ "//tools/checkescape", "//tools/checkunsafe", - "//tools/nogo/data", "@org_golang_x_tools//go/analysis:go_tool_library", "@org_golang_x_tools//go/analysis/internal/facts:go_tool_library", "@org_golang_x_tools//go/analysis/passes/asmdecl:go_tool_library", diff --git a/tools/nogo/build.go b/tools/nogo/build.go index 433d13738..37947b5c3 100644 --- a/tools/nogo/build.go +++ b/tools/nogo/build.go @@ -31,10 +31,10 @@ var ( ) // findStdPkg needs to find the bundled standard library packages. -func (i *importer) findStdPkg(path string) (io.ReadCloser, error) { +func findStdPkg(GOOS, GOARCH, path string) (io.ReadCloser, error) { if path == "C" { // Cgo builds cannot be analyzed. Skip. return nil, ErrSkip } - return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", i.GOOS, i.GOARCH, path)) + return os.Open(fmt.Sprintf("external/go_sdk/pkg/%s_%s/%s.a", GOOS, GOARCH, path)) } diff --git a/tools/nogo/check/BUILD b/tools/nogo/check/BUILD index e2d76cd5c..21ba2c306 100644 --- a/tools/nogo/check/BUILD +++ b/tools/nogo/check/BUILD @@ -7,6 +7,7 @@ package(licenses = ["notice"]) go_binary( name = "check", srcs = ["main.go"], + nogo = False, visibility = ["//visibility:public"], deps = ["//tools/nogo"], ) diff --git a/tools/nogo/config.go b/tools/nogo/config.go index 6958fca69..cfe7b4aa4 100644 --- a/tools/nogo/config.go +++ b/tools/nogo/config.go @@ -84,6 +84,14 @@ var analyzerConfig = map[*analysis.Analyzer]matcher{ externalExcluded( ".*protobuf/.*.go", // Bad conversions. ".*flate/huffman_bit_writer.go", // Bad conversion. + + // Runtime internal violations. + ".*reflect/value.go", + ".*encoding/xml/xml.go", + ".*runtime/pprof/internal/profile/proto.go", + ".*fmt/scan.go", + ".*go/types/conversions.go", + ".*golang.org/x/net/dns/dnsmessage/message.go", ), ), shadow.Analyzer: disableMatches(), // Disabled for now. @@ -114,3 +122,8 @@ var analyzerConfig = map[*analysis.Analyzer]matcher{ checkescape.Analyzer: internalMatches(), checkunsafe.Analyzer: internalMatches(), } + +var escapesConfig = map[*analysis.Analyzer]matcher{ + // Informational only: include all packages. + checkescape.EscapeAnalyzer: alwaysMatches(), +} diff --git a/tools/nogo/data/BUILD b/tools/nogo/data/BUILD deleted file mode 100644 index b7564cc44..000000000 --- a/tools/nogo/data/BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "data", - srcs = ["data.go"], - nogo = False, - visibility = ["//tools:__subpackages__"], -) diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index d399079c5..480438047 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -1,6 +1,107 @@ """Nogo rules.""" -load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule") +load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule", "go_test_library") + +def _nogo_dump_tool_impl(ctx): + # Extract the Go context. + go_ctx = go_context(ctx) + + # Construct the magic dump command. + # + # Note that in some cases, the input is being fed into the tool via stdin. + # Unfortunately, the Go objdump tool expects to see a seekable file [1], so + # we need the tool to handle this case by creating a temporary file. + # + # [1] https://github.com/golang/go/issues/41051 + env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()]) + dumper = ctx.actions.declare_file(ctx.label.name) + ctx.actions.write(dumper, "\n".join([ + "#!/bin/bash", + "set -euo pipefail", + "if [[ $# -eq 0 ]]; then", + " T=$(mktemp -u -t libXXXXXX.a)", + " cat /dev/stdin > ${T}", + "else", + " T=$1;", + "fi", + "%s %s tool objdump ${T}" % ( + env_prefix, + go_ctx.go.path, + ), + "if [[ $# -eq 0 ]]; then", + " rm -rf ${T}", + "fi", + "", + ]), is_executable = True) + + # Include the full runfiles. + return [DefaultInfo( + runfiles = ctx.runfiles(files = go_ctx.runfiles.to_list()), + executable = dumper, + )] + +nogo_dump_tool = go_rule( + rule, + implementation = _nogo_dump_tool_impl, +) + +# NogoStdlibInfo is the set of standard library facts. +NogoStdlibInfo = provider( + "information for nogo analysis (standard library facts)", + fields = { + "facts": "serialized standard library facts", + "findings": "package findings (if relevant)", + }, +) + +def _nogo_stdlib_impl(ctx): + # Extract the Go context. + go_ctx = go_context(ctx) + + # Build the standard library facts. + facts = ctx.actions.declare_file(ctx.label.name + ".facts") + findings = ctx.actions.declare_file(ctx.label.name + ".findings") + config = struct( + Srcs = [f.path for f in go_ctx.stdlib_srcs], + GOOS = go_ctx.goos, + GOARCH = go_ctx.goarch, + Tags = go_ctx.tags, + ) + config_file = ctx.actions.declare_file(ctx.label.name + ".cfg") + ctx.actions.write(config_file, config.to_json()) + ctx.actions.run( + inputs = [config_file] + go_ctx.stdlib_srcs, + outputs = [facts, findings], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._dump_tool), + executable = ctx.files._nogo[0], + mnemonic = "GoStandardLibraryAnalysis", + progress_message = "Analyzing Go Standard Library", + arguments = go_ctx.nogo_args + [ + "-dump_tool=%s" % ctx.files._dump_tool[0].path, + "-stdlib=%s" % config_file.path, + "-findings=%s" % findings.path, + "-facts=%s" % facts.path, + ], + ) + + # Return the stdlib facts as output. + return [NogoStdlibInfo( + facts = facts, + findings = findings, + )] + +nogo_stdlib = go_rule( + rule, + implementation = _nogo_stdlib_impl, + attrs = { + "_nogo": attr.label( + default = "//tools/nogo/check:check", + ), + "_dump_tool": attr.label( + default = "//tools/nogo:dump_tool", + ), + }, +) # NogoInfo is the serialized set of package facts for a nogo analysis. # @@ -8,10 +109,14 @@ load("//tools/bazeldefs:defs.bzl", "go_context", "go_importpath", "go_rule") # with the source files as input. Note however, that the individual nogo rules # are simply stubs that enter into the shadow dependency tree (the "aspect"). NogoInfo = provider( + "information for nogo analysis", fields = { "facts": "serialized package facts", + "findings": "package findings (if relevant)", "importpath": "package import path", "binaries": "package binary files", + "srcs": "original source files (for go_test support)", + "deps": "original deps (for go_test support)", }, ) @@ -21,17 +126,29 @@ def _nogo_aspect_impl(target, ctx): # All work is done in the shadow properties for go rules. For a proto # library, we simply skip the analysis portion but still need to return a # valid NogoInfo to reference the generated binary. - if ctx.rule.kind == "go_library": + if ctx.rule.kind in ("go_library", "go_binary", "go_test", "go_tool_library"): srcs = ctx.rule.files.srcs - elif ctx.rule.kind == "go_proto_library" or ctx.rule.kind == "go_wrap_cc": + deps = ctx.rule.attr.deps + elif ctx.rule.kind in ("go_proto_library", "go_wrap_cc"): srcs = [] + deps = ctx.rule.attr.deps else: return [NogoInfo()] + # Extract the Go context. go_ctx = go_context(ctx) - # Construct the Go environment from the go_ctx.env dictionary. - env_prefix = " ".join(["%s=%s" % (key, value) for (key, value) in go_ctx.env.items()]) + # If we're using the "library" attribute, then we need to aggregate the + # original library sources and dependencies into this target to perform + # proper type analysis. + if ctx.rule.kind == "go_test": + library = go_test_library(ctx.rule) + if library != None: + info = library[NogoInfo] + if hasattr(info, "srcs"): + srcs = srcs + info.srcs + if hasattr(info, "deps"): + deps = deps + info.deps # Start with all target files and srcs as input. inputs = target.files.to_list() + srcs @@ -41,50 +158,30 @@ def _nogo_aspect_impl(target, ctx): # to cleanly allow us redirect stdout to the actual output file. Perhaps # I'm missing something here, but the intermediate script does work. binaries = target.files.to_list() - disasm_file = ctx.actions.declare_file(target.label.name + ".out") - dumper = ctx.actions.declare_file("%s-dumper" % ctx.label.name) - ctx.actions.write(dumper, "\n".join([ - "#!/bin/bash", - "%s %s tool objdump %s > %s\n" % ( - env_prefix, - go_ctx.go.path, - [f.path for f in binaries if f.path.endswith(".a")][0], - disasm_file.path, - ), - ]), is_executable = True) - ctx.actions.run( - inputs = binaries, - outputs = [disasm_file], - tools = go_ctx.runfiles, - mnemonic = "GoObjdump", - progress_message = "Objdump %s" % target.label, - executable = dumper, - ) - inputs.append(disasm_file) + objfiles = [f for f in binaries if f.path.endswith(".a")] + if len(objfiles) > 0: + # Prefer the .a files for go_library targets. + target_objfile = objfiles[0] + else: + # Use the raw binary for go_binary and go_test targets. + target_objfile = binaries[0] + inputs.append(target_objfile) # Extract the importpath for this package. - importpath = go_importpath(target) - - # The nogo tool requires a configfile serialized in JSON format to do its - # work. This must line up with the nogo.Config fields. - facts = ctx.actions.declare_file(target.label.name + ".facts") - config = struct( - ImportPath = importpath, - GoFiles = [src.path for src in srcs if src.path.endswith(".go")], - NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], - # Google's internal build system needs a bit more help to find std. - StdZip = go_ctx.std_zip.short_path if hasattr(go_ctx, "std_zip") else "", - GOOS = go_ctx.goos, - GOARCH = go_ctx.goarch, - Tags = go_ctx.tags, - FactMap = {}, # Constructed below. - ImportMap = {}, # Constructed below. - FactOutput = facts.path, - Objdump = disasm_file.path, - ) + if ctx.rule.kind == "go_test": + # If this is a test, then it will not be imported by anything else. + # We can safely set the importapth to just "test". Note that this + # is necessary if the library also imports the core library (in + # addition to including the sources directly), which happens in + # some complex cases (seccomp_victim). + importpath = "test" + else: + importpath = go_importpath(target) # Collect all info from shadow dependencies. - for dep in ctx.rule.attr.deps: + fact_map = dict() + import_map = dict() + for dep in deps: # There will be no file attribute set for all transitive dependencies # that are not go_library or go_binary rules, such as a proto rules. # This is handled by the ctx.rule.kind check above. @@ -98,45 +195,83 @@ def _nogo_aspect_impl(target, ctx): x_files = [f.path for f in info.binaries if f.path.endswith(".x")] if not len(x_files): x_files = [f.path for f in info.binaries if f.path.endswith(".a")] - config.ImportMap[info.importpath] = x_files[0] - config.FactMap[info.importpath] = info.facts.path + import_map[info.importpath] = x_files[0] + fact_map[info.importpath] = info.facts.path # Ensure the above are available as inputs. inputs.append(info.facts) inputs += info.binaries - # Write the configuration and run the tool. + # Add the standard library facts. + stdlib_facts = ctx.attr._nogo_stdlib[NogoStdlibInfo].facts + inputs.append(stdlib_facts) + + # The nogo tool operates on a configuration serialized in JSON format. + facts = ctx.actions.declare_file(target.label.name + ".facts") + findings = ctx.actions.declare_file(target.label.name + ".findings") + escapes = ctx.actions.declare_file(target.label.name + ".escapes") + config = struct( + ImportPath = importpath, + GoFiles = [src.path for src in srcs if src.path.endswith(".go")], + NonGoFiles = [src.path for src in srcs if not src.path.endswith(".go")], + GOOS = go_ctx.goos, + GOARCH = go_ctx.goarch, + Tags = go_ctx.tags, + FactMap = fact_map, + ImportMap = import_map, + StdlibFacts = stdlib_facts.path, + ) config_file = ctx.actions.declare_file(target.label.name + ".cfg") ctx.actions.write(config_file, config.to_json()) inputs.append(config_file) - - # Run the nogo tool itself. ctx.actions.run( inputs = inputs, - outputs = [facts], - tools = go_ctx.runfiles, + outputs = [facts, findings, escapes], + tools = depset(go_ctx.runfiles.to_list() + ctx.files._dump_tool), executable = ctx.files._nogo[0], mnemonic = "GoStaticAnalysis", progress_message = "Analyzing %s" % target.label, - arguments = ["-config=%s" % config_file.path], + arguments = go_ctx.nogo_args + [ + "-binary=%s" % target_objfile.path, + "-dump_tool=%s" % ctx.files._dump_tool[0].path, + "-package=%s" % config_file.path, + "-findings=%s" % findings.path, + "-facts=%s" % facts.path, + "-escapes=%s" % escapes.path, + ], ) # Return the package facts as output. - return [NogoInfo( - facts = facts, - importpath = importpath, - binaries = binaries, - )] + return [ + NogoInfo( + facts = facts, + findings = findings, + importpath = importpath, + binaries = binaries, + srcs = srcs, + deps = deps, + ), + OutputGroupInfo( + # Expose all findings (should just be a single file). This can be + # used for build analysis of the nogo findings. + nogo_findings = depset([findings]), + # Expose all escape analysis findings (see above). + nogo_escapes = depset([escapes]), + ), + ] nogo_aspect = go_rule( aspect, implementation = _nogo_aspect_impl, - attr_aspects = ["deps"], + attr_aspects = [ + "deps", + "library", + "embed", + ], attrs = { - "_nogo": attr.label( - default = "//tools/nogo/check:check", - allow_single_file = True, - ), + "_nogo": attr.label(default = "//tools/nogo/check:check"), + "_nogo_stdlib": attr.label(default = "//tools/nogo:stdlib"), + "_dump_tool": attr.label(default = "//tools/nogo:dump_tool"), }, ) @@ -148,13 +283,26 @@ def _nogo_test_impl(ctx): # this way so that any test applied is effectively pushed down to all # upstream dependencies through the aspect. inputs = [] + findings = [] runner = ctx.actions.declare_file("%s-executer" % ctx.label.name) runner_content = ["#!/bin/bash"] for dep in ctx.attr.deps: + # Extract the findings. info = dep[NogoInfo] - inputs.append(info.facts) + inputs.append(info.findings) + findings.append(info.findings) + + # Include all source files, transitively. This will make this target + # "directly affected" for the purpose of build analysis. + inputs += info.srcs - # Draw a sweet unicode checkmark with the package name (in green). + # If there are findings, dump them and fail. + runner_content.append("if [[ -s \"%s\" ]]; then cat \"%s\" && exit 1; fi" % ( + info.findings.short_path, + info.findings.short_path, + )) + + # Otherwise, draw a sweet unicode checkmark with the package name (in green). runner_content.append("echo -e \"\\033[0;32m\\xE2\\x9C\\x94\\033[0;31m\\033[0m %s\"" % info.importpath) runner_content.append("exit 0\n") ctx.actions.write(runner, "\n".join(runner_content), is_executable = True) @@ -171,6 +319,10 @@ _nogo_test = rule( test = True, ) -def nogo_test(**kwargs): +def nogo_test(name, **kwargs): tags = kwargs.pop("tags", []) + ["nogo"] - _nogo_test(tags = tags, **kwargs) + _nogo_test( + name = name, + tags = tags, + **kwargs + ) diff --git a/tools/nogo/nogo.go b/tools/nogo/nogo.go index ea1e97076..120fdcff5 100644 --- a/tools/nogo/nogo.go +++ b/tools/nogo/nogo.go @@ -32,51 +32,89 @@ import ( "io/ioutil" "log" "os" + "path" "path/filepath" "reflect" + "strings" "golang.org/x/tools/go/analysis" "golang.org/x/tools/go/analysis/internal/facts" "golang.org/x/tools/go/gcexportdata" - "gvisor.dev/gvisor/tools/nogo/data" + + // Special case: flags live here and change overall behavior. + "gvisor.dev/gvisor/tools/checkescape" ) -// pkgConfig is serialized as the configuration. +// stdlibConfig is serialized as the configuration. // -// This contains everything required for the analysis. -type pkgConfig struct { - ImportPath string - GoFiles []string - NonGoFiles []string - Tags []string - GOOS string - GOARCH string - ImportMap map[string]string - FactMap map[string]string - FactOutput string - Objdump string - StdZip string +// This contains everything required for stdlib analysis. +type stdlibConfig struct { + Srcs []string + GOOS string + GOARCH string + Tags []string } -// loadFacts finds and loads facts per FactMap. -func (c *pkgConfig) loadFacts(path string) ([]byte, error) { - realPath, ok := c.FactMap[path] - if !ok { - return nil, nil // No facts available. - } +// packageConfig is serialized as the configuration. +// +// This contains everything required for single package analysis. +type packageConfig struct { + ImportPath string + GoFiles []string + NonGoFiles []string + Tags []string + GOOS string + GOARCH string + ImportMap map[string]string + FactMap map[string]string + StdlibFacts string +} - // Read the files file. - data, err := ioutil.ReadFile(realPath) - if err != nil { - return nil, err +// loader is a fact-loader function. +type loader func(string) ([]byte, error) + +// saver is a fact-saver function. +type saver func([]byte) error + +// factLoader returns a function that loads facts. +// +// This resolves all standard library facts and imported package facts up +// front. The returned loader function will never return an error, only +// empty facts. +// +// This is done because all stdlib data is stored together, and we don't want +// to load this data many times over. +func (c *packageConfig) factLoader() (loader, error) { + allFacts := make(map[string][]byte) + if c.StdlibFacts != "" { + data, err := ioutil.ReadFile(c.StdlibFacts) + if err != nil { + return nil, fmt.Errorf("error loading stdlib facts from %q: %w", c.StdlibFacts, err) + } + var stdlibFacts map[string][]byte + if err := json.Unmarshal(data, &stdlibFacts); err != nil { + return nil, fmt.Errorf("error loading stdlib facts: %w", err) + } + for pkg, data := range stdlibFacts { + allFacts[pkg] = data + } + } + for pkg, file := range c.FactMap { + data, err := ioutil.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("error loading %q: %w", file, err) + } + allFacts[pkg] = data } - return data, nil + return func(path string) ([]byte, error) { + return allFacts[path], nil + }, nil } // shouldInclude indicates whether the file should be included. // // NOTE: This does only basic parsing of tags. -func (c *pkgConfig) shouldInclude(path string) (bool, error) { +func (c *packageConfig) shouldInclude(path string) (bool, error) { ctx := build.Default ctx.GOOS = c.GOOS ctx.GOARCH = c.GOARCH @@ -90,10 +128,11 @@ func (c *pkgConfig) shouldInclude(path string) (bool, error) { // files, and the facts. Note that this importer implementation will always // pass when a given package is not available. type importer struct { - pkgConfig - fset *token.FileSet - cache map[string]*types.Package - lastErr error + *packageConfig + fset *token.FileSet + cache map[string]*types.Package + lastErr error + callback func(string) error } // Import implements types.Importer.Import. @@ -104,6 +143,17 @@ func (i *importer) Import(path string) (*types.Package, error) { // analyzers are specifically looking for this. return types.Unsafe, nil } + + // Call the internal callback. This is used to resolve loading order + // for the standard library. See checkStdlib. + if i.callback != nil { + if err := i.callback(path); err != nil { + i.lastErr = err + return nil, err + } + } + + // Actually load the data. realPath, ok := i.ImportMap[path] var ( rc io.ReadCloser @@ -112,7 +162,7 @@ func (i *importer) Import(path string) (*types.Package, error) { if !ok { // Not found in the import path. Attempt to find the package // via the standard library. - rc, err = i.findStdPkg(path) + rc, err = findStdPkg(i.GOOS, i.GOARCH, path) } else { // Open the file. rc, err = os.Open(realPath) @@ -135,6 +185,151 @@ func (i *importer) Import(path string) (*types.Package, error) { // ErrSkip indicates the package should be skipped. var ErrSkip = errors.New("skipped") +// checkStdlib checks the standard library. +// +// This constructs a synthetic package configuration for each library in the +// standard library sources, and call checkPackage repeatedly. +// +// Note that not all parts of the source are expected to build. We skip obvious +// test files, and cmd files, which should not be dependencies. +func checkStdlib(config *stdlibConfig, ac map[*analysis.Analyzer]matcher) ([]string, []byte, error) { + if len(config.Srcs) == 0 { + return nil, nil, nil + } + + // Ensure all paths are normalized. + for i := 0; i < len(config.Srcs); i++ { + config.Srcs[i] = path.Clean(config.Srcs[i]) + } + + // Calculate the root source directory. This is always a directory + // named 'src', of which we simply take the first we find. This is a + // bit fragile, but works for all currently known Go source + // configurations. + // + // Note that there may be extra files outside of the root source + // directory; we simply ignore those. + rootSrcPrefix := "" + for _, file := range config.Srcs { + const src = "/src/" + i := strings.Index(file, src) + if i == -1 { + // Superfluous file. + continue + } + + // Index of first character after /src/. + i += len(src) + rootSrcPrefix = file[:i] + break + } + + // Aggregate all files by directory. + packages := make(map[string]*packageConfig) + for _, file := range config.Srcs { + if !strings.HasPrefix(file, rootSrcPrefix) { + // Superflouous file. + continue + } + + d := path.Dir(file) + if len(rootSrcPrefix) >= len(d) { + continue // Not a file. + } + pkg := d[len(rootSrcPrefix):] + // Skip cmd packages and obvious test files: see above. + if strings.HasPrefix(pkg, "cmd/") || strings.HasSuffix(file, "_test.go") { + continue + } + c, ok := packages[pkg] + if !ok { + c = &packageConfig{ + ImportPath: pkg, + GOOS: config.GOOS, + GOARCH: config.GOARCH, + Tags: config.Tags, + } + packages[pkg] = c + } + // Add the files appropriately. Note that they will be further + // filtered by architecture and build tags below, so this need + // not be done immediately. + if strings.HasSuffix(file, ".go") { + c.GoFiles = append(c.GoFiles, file) + } else { + c.NonGoFiles = append(c.NonGoFiles, file) + } + } + + // Closure to check a single package. + allFindings := make([]string, 0) + stdlibFacts := make(map[string][]byte) + var checkOne func(pkg string) error // Recursive. + checkOne = func(pkg string) error { + // Is this already done? + if _, ok := stdlibFacts[pkg]; ok { + return nil + } + + // Lookup the configuration. + config, ok := packages[pkg] + if !ok { + return nil // Not known. + } + + // Find the binary package, and provide to objdump. + rc, err := findStdPkg(config.GOOS, config.GOARCH, pkg) + if err != nil { + // If there's no binary for this package, it is likely + // not built with the distribution. That's fine, we can + // just skip analysis. + return nil + } + + // Provide the input. + oldReader := checkescape.Reader + checkescape.Reader = rc // For analysis. + defer func() { + rc.Close() + checkescape.Reader = oldReader // Restore. + }() + + // Run the analysis. + findings, factData, err := checkPackage(config, ac, checkOne) + if err != nil { + // If we can't analyze a package from the standard library, + // then we skip it. It will simply not have any findings. + return nil + } + stdlibFacts[pkg] = factData + allFindings = append(allFindings, findings...) + return nil + } + + // Check all packages. + // + // Note that this may call checkOne recursively, so it's not guaranteed + // to evaluate in the order provided here. We do ensure however, that + // all packages are evaluated. + for pkg := range packages { + checkOne(pkg) + } + + // Sanity check. + if len(stdlibFacts) == 0 { + return nil, nil, fmt.Errorf("no stdlib facts found: misconfiguration?") + } + + // Write out all findings. + factData, err := json.Marshal(stdlibFacts) + if err != nil { + return nil, nil, fmt.Errorf("error saving stdlib facts: %w", err) + } + + // Return all findings. + return allFindings, factData, nil +} + // checkPackage runs all analyzers. // // The implementation was adapted from [1], which was in turn adpated from [2]. @@ -143,11 +338,12 @@ var ErrSkip = errors.New("skipped") // // [1] bazelbuid/rules_go/tools/builders/nogo_main.go // [2] golang.org/x/tools/go/checker/internal/checker -func checkPackage(config pkgConfig) ([]string, error) { +func checkPackage(config *packageConfig, ac map[*analysis.Analyzer]matcher, importCallback func(string) error) ([]string, []byte, error) { imp := &importer{ - pkgConfig: config, - fset: token.NewFileSet(), - cache: make(map[string]*types.Package), + packageConfig: config, + fset: token.NewFileSet(), + cache: make(map[string]*types.Package), + callback: importCallback, } // Load all source files. @@ -155,14 +351,14 @@ func checkPackage(config pkgConfig) ([]string, error) { for _, file := range config.GoFiles { include, err := config.shouldInclude(file) if err != nil { - return nil, fmt.Errorf("error evaluating file %q: %v", file, err) + return nil, nil, fmt.Errorf("error evaluating file %q: %v", file, err) } if !include { continue } s, err := parser.ParseFile(imp.fset, file, nil, parser.ParseComments) if err != nil { - return nil, fmt.Errorf("error parsing file %q: %v", file, err) + return nil, nil, fmt.Errorf("error parsing file %q: %v", file, err) } syntax = append(syntax, s) } @@ -180,17 +376,18 @@ func checkPackage(config pkgConfig) ([]string, error) { } types, err := typeConfig.Check(config.ImportPath, imp.fset, syntax, typesInfo) if err != nil && imp.lastErr != ErrSkip { - return nil, fmt.Errorf("error checking types: %w", err) + return nil, nil, fmt.Errorf("error checking types: %w", err) } // Load all package facts. - facts, err := facts.Decode(types, config.loadFacts) + loader, err := config.factLoader() if err != nil { - return nil, fmt.Errorf("error decoding facts: %w", err) + return nil, nil, fmt.Errorf("error loading facts: %w", err) + } + facts, err := facts.Decode(types, loader) + if err != nil { + return nil, nil, fmt.Errorf("error decoding facts: %w", err) } - - // Set the binary global for use. - data.Objdump = config.Objdump // Register fact types and establish dependencies between analyzers. // The visit closure will execute recursively, and populate results @@ -211,7 +408,7 @@ func checkPackage(config pkgConfig) ([]string, error) { } // Prepare the matcher. - m := analyzerConfig[a] + m := ac[a] report := func(d analysis.Diagnostic) { if m.ShouldReport(d, imp.fset) { diagnostics[a] = append(diagnostics[a], d) @@ -252,21 +449,13 @@ func checkPackage(config pkgConfig) ([]string, error) { return nil // Success. } - // Visit all analysis recursively. - for a, _ := range analyzerConfig { + // Visit all analyzers recursively. + for a, _ := range ac { if imp.lastErr == ErrSkip { continue // No local analysis. } if err := visit(a); err != nil { - return nil, err // Already has context. - } - } - - // Write the output file. - if config.FactOutput != "" { - factData := facts.Encode() - if err := ioutil.WriteFile(config.FactOutput, factData, 0644); err != nil { - return nil, fmt.Errorf("error: unable to open facts output %q: %v", config.FactOutput, err) + return nil, nil, err // Already has context. } } @@ -280,47 +469,104 @@ func checkPackage(config pkgConfig) ([]string, error) { } // Return all findings. - return findings, nil + factData := facts.Encode() + return findings, factData, nil } var ( - configFile = flag.String("config", "", "configuration file (in JSON format)") + packageFile = flag.String("package", "", "package configuration file (in JSON format)") + stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") + findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") + factsOutput = flag.String("facts", "", "output file for facts (optional)") + escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") ) -// Main is the entrypoint; it should be called directly from main. -// -// N.B. This package registers it's own flags. -func Main() { - // Parse all flags. - flag.Parse() - +func loadConfig(file string, config interface{}) interface{} { // Load the configuration. - f, err := os.Open(*configFile) + f, err := os.Open(file) if err != nil { - log.Fatalf("unable to open configuration %q: %v", *configFile, err) + log.Fatalf("unable to open configuration %q: %v", file, err) } defer f.Close() - config := new(pkgConfig) dec := json.NewDecoder(f) dec.DisallowUnknownFields() if err := dec.Decode(config); err != nil { log.Fatalf("unable to decode configuration: %v", err) } + return config +} + +// Main is the entrypoint; it should be called directly from main. +// +// N.B. This package registers it's own flags. +func Main() { + // Parse all flags. + flag.Parse() - // Process the package. - findings, err := checkPackage(*config) + var ( + findings []string + factData []byte + err error + ) + + // Check the configuration. + if *packageFile != "" && *stdlibFile != "" { + log.Fatalf("unable to perform stdlib and package analysis; provide only one!") + } else if *stdlibFile != "" { + // Perform basic analysis. + c := loadConfig(*stdlibFile, new(stdlibConfig)).(*stdlibConfig) + findings, factData, err = checkStdlib(c, analyzerConfig) + } else if *packageFile != "" { + // Perform basic analysis. + c := loadConfig(*packageFile, new(packageConfig)).(*packageConfig) + findings, factData, err = checkPackage(c, analyzerConfig, nil) + // Do we need to do escape analysis? + if *escapesOutput != "" { + escapes, _, err := checkPackage(c, escapesConfig, nil) + if err != nil { + log.Fatalf("error performing escape analysis: %v", err) + } + f, err := os.OpenFile(*escapesOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Fatalf("unable to open output %q: %v", *escapesOutput, err) + } + defer f.Close() + for _, escape := range escapes { + fmt.Fprintf(f, "%s\n", escape) + } + } + } else { + log.Fatalf("please provide at least one of package or stdlib!") + } + + // Save facts. + if *factsOutput != "" { + if err := ioutil.WriteFile(*factsOutput, factData, 0644); err != nil { + log.Fatalf("error saving findings to %q: %v", *factsOutput, err) + } + } + + // Open the output file. + var w io.Writer = os.Stdout + if *findingsOutput != "" { + f, err := os.OpenFile(*findingsOutput, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + log.Fatalf("unable to open output %q: %v", *findingsOutput, err) + } + defer f.Close() + w = f + } + + // Handle findings & errors. if err != nil { log.Fatalf("error checking package: %v", err) } - - // No findings? if len(findings) == 0 { - os.Exit(0) + return } - // Print findings and exit with non-zero code. + // Print findings. for _, finding := range findings { - fmt.Fprintf(os.Stdout, "%s\n", finding) + fmt.Fprintf(w, "%s\n", finding) } - os.Exit(1) } diff --git a/tools/nogo/register.go b/tools/nogo/register.go index 62b499661..34b173937 100644 --- a/tools/nogo/register.go +++ b/tools/nogo/register.go @@ -26,6 +26,9 @@ func analyzers() (all []*analysis.Analyzer) { for a, _ := range analyzerConfig { all = append(all, a) } + for a, _ := range escapesConfig { + all = append(all, a) + } return all } diff --git a/website/BUILD b/website/BUILD index 7b61d13c8..6d92d9103 100644 --- a/website/BUILD +++ b/website/BUILD @@ -157,6 +157,7 @@ docs( "//g3doc/user_guide/quick_start:oci", "//g3doc/user_guide/tutorials:cni", "//g3doc/user_guide/tutorials:docker", + "//g3doc/user_guide/tutorials:docker_compose", "//g3doc/user_guide/tutorials:kubernetes", ], ) diff --git a/website/_config.yml b/website/_config.yml index b08602970..20fbb3d2d 100644 --- a/website/_config.yml +++ b/website/_config.yml @@ -34,3 +34,6 @@ authors: igudger: name: Ian Gudger email: igudger@google.com + fvoznika: + name: Fabricio Voznika + email: fvoznika@google.com diff --git a/website/_sass/front.scss b/website/_sass/front.scss index 0e4208f3c..f1b060560 100644 --- a/website/_sass/front.scss +++ b/website/_sass/front.scss @@ -1,5 +1,5 @@ .jumbotron { - background-image: url(/assets/images/background.jpg); + background-image: url(/assets/images/background_1080p.jpg); background-position: center; background-repeat: no-repeat; background-size: cover; diff --git a/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png b/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png Binary files differnew file mode 100644 index 000000000..c750f0851 --- /dev/null +++ b/website/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png diff --git a/website/assets/images/background_1080p.jpg b/website/assets/images/background_1080p.jpg Binary files differnew file mode 100644 index 000000000..d312595a6 --- /dev/null +++ b/website/assets/images/background_1080p.jpg diff --git a/website/blog/2019-11-18-security-basics.md b/website/blog/2019-11-18-security-basics.md index 76bbabc13..b6cf57a77 100644 --- a/website/blog/2019-11-18-security-basics.md +++ b/website/blog/2019-11-18-security-basics.md @@ -188,7 +188,7 @@ for direct access to some files. And most files will be remotely accessed through the Gofers, in which case no FDs are donated to the Sentry. The Sentry itself is only allowed access to specific -[whitelisted syscalls](https://github.com/google/gvisor/blob/master/runsc/boot/config.go). +[whitelisted syscalls](https://github.com/google/gvisor/blob/master/runsc/config/config.go). Without networking, the Sentry needs 53 host syscalls in order to function, and with networking, it uses an additional 15[^8]. By limiting the whitelist to only these needed syscalls, we radically reduce the amount of host OS attack surface. @@ -279,8 +279,10 @@ weaknesses of each gVisor component. We will also use it to introduce Google's Vulnerability Reward Program[^14], and other ways the community can contribute to help make gVisor safe, fast and stable. +<br> +<br> -## Notes +-------------------------------------------------------------------------------- [^1]: [https://en.wikipedia.org/wiki/Secure_by_design](https://en.wikipedia.org/wiki/Secure_by_design) [^2]: [https://gvisor.dev/docs/architecture_guide](https://gvisor.dev/docs/architecture_guide/) diff --git a/website/blog/2020-09-18-containing-a-real-vulnerability.md b/website/blog/2020-09-18-containing-a-real-vulnerability.md new file mode 100644 index 000000000..b71ef63d9 --- /dev/null +++ b/website/blog/2020-09-18-containing-a-real-vulnerability.md @@ -0,0 +1,224 @@ +# Containing a Real Vulnerability + +In the previous two posts we talked about gVisor's +[security design principles](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/) +as well as how those are applied in the +[context of networking](https://gvisor.dev/blog/2020/04/02/gvisor-networking-security/). +Recently, a new container escape vulnerability +([CVE-2020-14386](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-14386)) +was announced that ties these topics well together. gVisor is +[not vulnerable](https://seclists.org/oss-sec/2020/q3/168) to this specific +issue, but it provides an interesting case study to continue our exploration of +gVisor's security. While gVisor is not immune to vulnerabilities, +[we take several steps](https://gvisor.dev/security/) to minimize the impact and +remediate if a vulnerability is found. + +## Escaping the Container + +First, let’s describe how the discovered vulnerability works. There are numerous +ways one can send and receive bytes over the network with Linux. One of the most +performant ways is to use a ring buffer, which is a memory region shared by the +application and the kernel. These rings are created by calling +[setsockopt(2)](https://man7.org/linux/man-pages/man2/setsockopt.2.html) with +[`PACKET_RX_RING`](https://man7.org/linux/man-pages/man7/packet.7.html) for +receiving and +[`PACKET_TX_RING`](https://man7.org/linux/man-pages/man7/packet.7.html) for +sending packets. + +The vulnerability is in the code that reads packets when `PACKET_RX_RING` is +enabled. There is another option +([`PACKET_RESERVE`](https://man7.org/linux/man-pages/man7/packet.7.html)) that +asks the kernel to leave some space in the ring buffer before each packet for +anything the application needs, e.g. control structures. When a packet is +received, the kernel calculates where to copy the packet to, taking the amount +reserved before each packet into consideration. If the amount reserved is large, +the kernel performed an incorrect calculation which could cause an overflow +leading to an out-of-bounds write of up to 10 bytes, controlled by the attacker. +The data in the write is easily controlled using the loopback to send a crafted +packet and receiving it using a `PACKET_RX_RING` with a carefully selected +`PACKET_RESERVE` size. + +```c +static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ +// ... + if (sk->sk_type == SOCK_DGRAM) { + macoff = netoff = TPACKET_ALIGN(po->tp_hdrlen) + 16 + + po->tp_reserve; + } else { + unsigned int maclen = skb_network_offset(skb); + // tp_reserve is unsigned int, netoff is unsigned short. Addition can overflow netoff + netoff = TPACKET_ALIGN(po->tp_hdrlen + + (maclen < 16 ? 16 : maclen)) + + po->tp_reserve; + if (po->has_vnet_hdr) { + netoff += sizeof(struct virtio_net_hdr); + do_vnet = true; + } + // Attacker controls netoff and can make macoff be smaller than sizeof(struct virtio_net_hdr) + macoff = netoff - maclen; + } +// ... + // "macoff - sizeof(struct virtio_net_hdr)" can be negative, resulting in a pointer before h.raw + if (do_vnet && + virtio_net_hdr_from_skb(skb, h.raw + macoff - + sizeof(struct virtio_net_hdr), + vio_le(), true, 0)) { +// ... +``` + +The [`CAP_NET_RAW`](https://man7.org/linux/man-pages/man7/capabilities.7.html) +capability is required to create the socket above. However, in order to support +common debugging tools like `ping` and `tcpdump`, Docker containers, including +those created for Kubernetes, are given `CAP_NET_RAW` by default and thus may be +able to trigger this vulnerability to elevate privileges and escape the +container. + +Next, we are going to explore why this vulnerability doesn’t work in gVisor, and +how gVisor could prevent the escape even if a similar vulnerability existed +inside gVisor’s kernel. + +## Default Protections + +gVisor does not implement `PACKET_RX_RING`, but **does** support raw sockets +which are required for `PACKET_RX_RING`. Raw sockets are a controversial feature +to support in a sandbox environment. While it allows great customizations for +essential tools like `ping`, it may allow packets to be written to the network +without any validation. In general, allowing an untrusted application to write +crafted packets to the network is a questionable idea and a historical source of +vulnerabilities. With that in mind, if `CAP_NET_RAW` is enabled by default, it +would not be _secure by default_ to run untrusted applications. + +After multiple discussions when raw sockets were first implemented, we decided +to disable raw sockets by default, **even if `CAP_NET_RAW` is given to the +application**. Instead, enabling raw sockets in gVisor requires the admin to set +`--net-raw` flag to runsc when configuring the runtime, in addition to requiring +the `CAP_NET_RAW` capability in the application. It comes at the expense that +some tools may not work out of the box, but as part of our +[secure-by-default](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#secure-by-default) +principle, we felt that it was important for the “less secure” configuration to +be explicit. + +Since this bug was due to an overflow in the specific Linux implementation of +the packet ring, gVisor's raw socket implementation is not affected. However, if +there were a vulnerability in gVisor, containers would not be allowed to exploit +it by default. + +As an alternative way to implement this same constraint, Kubernetes allows +[admission controllers](https://kubernetes.io/docs/reference/access-authn-authz/admission-controllers/) +to be configured to customize requests. Cloud providers can use this to +implement more stringent policies. For example, GKE implements an admission +controller for gVisor that +[removes `CAP_NET_RAW` from gVisor pods](https://cloud.google.com/kubernetes-engine/docs/concepts/sandbox-pods#capabilities) +unless it has been explicitly set in the pod spec. + +## Isolated Kernel + +gVisor has its own application kernel, called the Sentry, that is distinct from +the host kernel. Just like what you would expect from a kernel, gVisor has a +memory management subsystem, virtual file system, and a full network stack. The +host network is only used as a transport to carry packets in and out the +sandbox[^1]. The loopback interface which is used in the exploit stays +completely inside the sandbox, never reaching the host. + +Therefore, even if the Sentry was vulnerable to the attack, there would be two +factors that would prevent a container escape from happening. First, the +vulnerability would be limited to the Sentry, and the attacker would compromise +only the application kernel, bound by a restricted set of +[seccomp](https://en.wikipedia.org/wiki/Seccomp) filters, discussed more in +depth below. Second, the Sentry is a distinct implementation of the API, written +in Go, which provides bounds checking that would have likely prevented access +past the bounds of the shared region (e.g. see +[aio](https://cs.opensource.google/gvisor/gvisor/+/master:pkg/sentry/syscalls/linux/vfs2/aio.go;l=210;drc=a11061d78a58ed75b10606d1a770b035ed944b66?q=file:aio&ss=gvisor%2Fgvisor) +or +[kcov](https://cs.opensource.google/gvisor/gvisor/+/master:pkg/sentry/kernel/kcov.go;l=272?q=file:kcov&ss=gvisor%2Fgvisor), +which have similar shared regions). + +Here, Kubernetes warrants slightly more explanation. gVisor makes pods the unit +of isolation and a pod can run multiple containers. In other words, each pod is +a gVisor instance, and each container is a set of processes running inside +gVisor, isolated via Sentry-internal namespaces like regular containers inside a +pod. If there were a vulnerability in gVisor, the privilege escalation would +allow a container inside the pod to break out to other **containers inside the +same pod**, but the container still **cannot break out of the pod**. + +## Defense in Depth + +gVisor follows a +[common security principle used at Google](https://cloud.google.com/security/infrastructure/design/resources/google_infrastructure_whitepaper_fa.pdf) +that the system should have two layers of protection, and those layers should +require different compromises to be broken. We apply this principle by assuming +that the Sentry (first layer of defense) +[will be compromised and should not be trusted](https://gvisor.dev/blog/2019/11/18/gvisor-security-basics-part-1/#defense-in-depth). +In order to protect the host kernel from a compromised Sentry, we wrap it around +many security and isolations features to ensure only the minimal set of +functionality from the host kernel is exposed. + +![Figure 1](/assets/images/2020-09-18-containing-a-real-vulnerability-figure1.png "Protection layers.") + +First, the sandbox runs inside a cgroup that can limit and throttle host +resources being used. Second, the sandbox joins empty namespaces, including user +and mount, to further isolate from the host. Next, it changes the process root +to a read-only directory that contains only `/proc` and nothing else. Then, it +executes with the unprivileged user/group +[`nobody`](https://en.wikipedia.org/wiki/Nobody_\(username\)) with all +capabilities stripped. Last and most importantly, a seccomp filter is added to +tightly restrict what parts of the Linux syscall surface that gVisor is allowed +to access. The allowed host surface is a far smaller set of syscalls than the +Sentry implements for applications to use. Not only restricting the syscall +being called, but also checking that arguments to these syscalls are within the +expected set. Dangerous syscalls like <code>execve(2)</code>, +<code>open(2)</code>, and <code>socket(2)</code> are prohibited, thus an +attacker isn’t able to execute binaries or acquire new resources on the host. + +if there were a vulnerability in gVisor that allowed an attacker to execute code +inside the Sentry, the attacker still has extremely limited privileges on the +host. In fact, a compromised Sentry is much more restricted than a +non-compromised regular container. For CVE-2020-14386 in particular, the attack +would be blocked by more than one security layer: non-privileged user, no +capability, and seccomp filters. + +Although the surface is drastically reduced, there is still a chance that there +is a vulnerability in one of the allowed syscalls. That’s why it’s important to +keep the surface small and carefully consider what syscalls are allowed. You can +find the full set of allowed syscalls +[here](https://cs.opensource.google/gvisor/gvisor/+/master:runsc/boot/filter/). + +Another possible attack vector is resources that are present in the Sentry, like +open file descriptors. The Sentry has file descriptors that an attacker could +potentially use, such as log files, platform files (e.g. `/dev/kvm`), an RPC +endpoint that allows external communication with the Sentry, and a Netstack +endpoint that connects the sandbox to the network. The Netstack endpoint in +particular is a concern because it gives direct access to the network. It’s an +`AF_PACKET` socket that allows arbitrary L2 packets to be written to the +network. In the normal case, Netstack assembles packets that go out the network, +giving the container control over only the payload. But if the Sentry is +compromised, an attacker can craft packets to the network. In many ways this is +similar to anyone sending random packets over the internet, but still this is a +place where the host kernel surface exposed is larger than we would like it to +be. + +## Conclusion + +Security comes with many tradeoffs that are often hard to make, such as the +decision to disable raw sockets by default. However, these tradeoffs have served +us well, and we've found them to have paid off over time. CVE-2020-14386 offers +great insight into how multiple layers of protection can be effective against +such an attack. + +We cannot guarantee that a container escape will never happen in gVisor, but we +do our best to make it as hard as we possibly can. + +If you have not tried gVisor yet, it’s easier than you think. Just follow the +steps in the +[Quick Start](https://gvisor.dev/docs/user_guide/quick_start/docker/) guide. +<br> +<br> + +-------------------------------------------------------------------------------- + +[^1]: Those packets are eventually handled by the host, as it needs to route + them to local containers or send them out the NIC. The packet will be + handled by many switches, routers, proxies, servers, etc. along the way, + which may be subject to their own vulnerabilities. diff --git a/website/blog/BUILD b/website/blog/BUILD index 01c1f5a6e..865e403da 100644 --- a/website/blog/BUILD +++ b/website/blog/BUILD @@ -28,6 +28,16 @@ doc( permalink = "/blog/2020/04/02/gvisor-networking-security/", ) +doc( + name = "containing_a_real_vulnerability", + src = "2020-09-18-containing-a-real-vulnerability.md", + authors = [ + "fvoznika", + ], + layout = "post", + permalink = "/blog/2020/09/18/containing-a-real-vulnerability/", +) + docs( name = "posts", deps = [ diff --git a/website/css/main.scss b/website/css/main.scss index 06106833f..4b3b7b500 100644 --- a/website/css/main.scss +++ b/website/css/main.scss @@ -1,5 +1,10 @@ -@import 'style.scss'; -@import 'front.scss'; -@import 'navbar.scss'; -@import 'sidebar.scss'; -@import 'footer.scss'; +// The main style sheet for gvisor.dev + +// NOTE: Do not include file extensions to import .sass and .css files seamlessly. +@import "style"; +@import "front"; +@import "navbar"; +@import "sidebar"; +@import "footer"; +// syntax is generated by rougify. +@import "syntax"; diff --git a/website/index.md b/website/index.md index 84f877d49..c6cd477c2 100644 --- a/website/index.md +++ b/website/index.md @@ -5,7 +5,7 @@ <div class="col-md-6"> <p>gVisor is an <b>application kernel</b> for <b>containers</b> that provides efficient defense-in-depth anywhere.</p> <p style="margin-top: 20px;"> - <a class="btn" href="/docs/user_guide/quick_start/docker/">Quick start <i class="fas fa-arrow-alt-circle-right ml-2"></i></a> + <a class="btn" href="/docs/user_guide/install/">Get started <i class="fas fa-arrow-alt-circle-right ml-2"></i></a> <a class="btn" href="/docs/">Learn More <i class="fas fa-arrow-alt-circle-right ml-2"></i></a> </p> </div> |